[Streaming] Streaming Cross-Lang API (#7464)
parent 101255f782, commit 91f630f709
72 changed files with 1612 additions and 408 deletions
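Taken together, the changes below add Java/Python interoperability to Ray Streaming: proxy streams (asPythonStream/asJavaStream), string-based references to Python functions, and a ClusterStarter that boots a multi-process cluster for cross-language graphs. A minimal end-to-end sketch assembled from the APIs visible in this diff — the Python module and function names are hypothetical placeholders:

import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.stream.DataStreamSource;
import java.util.Arrays;

public class CrossLangSketch {
  public static void main(String[] args) {
    StreamingContext context = StreamingContext.buildContext();
    DataStreamSource.fromCollection(context, Arrays.asList("a", "b", "c"))
        .asPythonStream()                 // hop into python; same logical stream id
        .map("my_module", "my_map_func")  // python function, resolved via importlib
        .asJavaStream()                   // hop back into java
        .sink(x -> System.out.println(x));
  }
}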
@@ -542,6 +542,7 @@ def init(address=None,
         raylet_socket_name=None,
         temp_dir=None,
         load_code_from_local=False,
         java_worker_options=None,
         use_pickle=True,
         _internal_config=None,
         lru_evict=False):

@@ -651,6 +652,7 @@ def init(address=None,
            conventional location, e.g., "/tmp/ray".
        load_code_from_local: Whether code should be loaded from a local
            module or from the GCS.
        java_worker_options: Overwrite the options to start Java workers.
        use_pickle: Deprecated.
        _internal_config (str): JSON configuration for overriding
            RayConfig defaults. For testing purposes ONLY.

@@ -758,6 +760,7 @@ def init(address=None,
        raylet_socket_name=raylet_socket_name,
        temp_dir=temp_dir,
        load_code_from_local=load_code_from_local,
        java_worker_options=java_worker_options,
        _internal_config=_internal_config,
    )
    # Start the Ray processes. We set shutdown_at_exit=False because we

@@ -808,6 +811,9 @@ def init(address=None,
        if raylet_socket_name is not None:
            raise ValueError("When connecting to an existing cluster, "
                             "raylet_socket_name must not be provided.")
        if java_worker_options is not None:
            raise ValueError("When connecting to an existing cluster, "
                             "java_worker_options must not be provided.")
        if _internal_config is not None and len(_internal_config) != 0:
            raise ValueError("When connecting to an existing cluster, "
                             "_internal_config must not be provided.")

@@ -39,6 +39,7 @@ define_java_module(
        ":io_ray_ray_streaming-state",
        ":io_ray_ray_streaming-api",
        "@ray_streaming_maven//:com_google_guava_guava",
        "@ray_streaming_maven//:org_apache_commons_commons_lang3",
        "@ray_streaming_maven//:org_slf4j_slf4j_api",
        "@ray_streaming_maven//:org_slf4j_slf4j_log4j12",
        "@ray_streaming_maven//:org_testng_testng",

@@ -46,7 +47,12 @@ define_java_module(
    visibility = ["//visibility:public"],
    deps = [
        ":io_ray_ray_streaming-state",
        "//java:io_ray_ray_api",
        "//java:io_ray_ray_runtime",
        "@ray_streaming_maven//:com_google_code_findbugs_jsr305",
        "@ray_streaming_maven//:com_google_code_gson_gson",
        "@ray_streaming_maven//:com_google_guava_guava",
        "@ray_streaming_maven//:org_apache_commons_commons_lang3",
        "@ray_streaming_maven//:org_slf4j_slf4j_api",
        "@ray_streaming_maven//:org_slf4j_slf4j_log4j12",
    ],

@@ -129,8 +135,9 @@ define_java_module(
        ":io_ray_ray_streaming-api",
        ":io_ray_ray_streaming-runtime",
        "@ray_streaming_maven//:com_google_guava_guava",
        "@ray_streaming_maven//:com_google_code_findbugs_jsr305",
        "@ray_streaming_maven//:org_apache_commons_commons_lang3",
        "@ray_streaming_maven//:de_ruedigermoeller_fst",
        "@ray_streaming_maven//:org_msgpack_msgpack_core",
        "@ray_streaming_maven//:org_aeonbits_owner_owner",
        "@ray_streaming_maven//:org_slf4j_slf4j_api",
        "@ray_streaming_maven//:org_slf4j_slf4j_log4j12",

@@ -146,10 +153,12 @@ define_java_module(
        "//java:io_ray_ray_api",
        "//java:io_ray_ray_runtime",
        "@ray_streaming_maven//:com_github_davidmoten_flatbuffers_java",
        "@ray_streaming_maven//:com_google_code_findbugs_jsr305",
        "@ray_streaming_maven//:com_google_guava_guava",
        "@ray_streaming_maven//:com_google_protobuf_protobuf_java",
        "@ray_streaming_maven//:de_ruedigermoeller_fst",
        "@ray_streaming_maven//:org_aeonbits_owner_owner",
        "@ray_streaming_maven//:org_apache_commons_commons_lang3",
        "@ray_streaming_maven//:org_msgpack_msgpack_core",
        "@ray_streaming_maven//:org_slf4j_slf4j_api",
        "@ray_streaming_maven//:org_slf4j_slf4j_log4j12",

@@ -6,8 +6,11 @@ def gen_streaming_java_deps():
    artifacts = [
        "com.beust:jcommander:1.72",
        "com.google.guava:guava:27.0.1-jre",
        "com.google.code.findbugs:jsr305:3.0.2",
        "com.google.code.gson:gson:2.8.5",
        "com.github.davidmoten:flatbuffers-java:1.9.0.1",
        "com.google.protobuf:protobuf-java:3.8.0",
        "org.apache.commons:commons-lang3:3.4",
        "de.ruedigermoeller:fst:2.57",
        "org.aeonbits.owner:owner:1.0.10",
        "org.slf4j:slf4j-api:1.7.12",

@@ -19,10 +22,9 @@ def gen_streaming_java_deps():
        "org.apache.commons:commons-lang3:3.3.2",
        "org.msgpack:msgpack-core:0.8.20",
        "org.testng:testng:6.9.10",
        "org.mockito:mockito-all:1.10.19",
        "org.powermock:powermock-module-testng:1.6.6",
        "org.powermock:powermock-api-mockito:1.6.6",
        "org.projectlombok:lombok:1.16.20",
        "org.mockito:mockito-all:1.10.19",
        "org.powermock:powermock-module-testng:1.6.6",
        "org.powermock:powermock-api-mockito:1.6.6",
    ],
    repositories = [
        "https://repo1.maven.org/maven2/",

@@ -22,16 +22,36 @@
      <artifactId>ray-api</artifactId>
      <version>${project.version}</version>
    </dependency>
    <dependency>
      <groupId>io.ray</groupId>
      <artifactId>ray-runtime</artifactId>
      <version>${project.version}</version>
    </dependency>
    <dependency>
      <groupId>org.ray</groupId>
      <artifactId>streaming-state</artifactId>
      <version>${project.version}</version>
    </dependency>
    <dependency>
      <groupId>com.google.code.findbugs</groupId>
      <artifactId>jsr305</artifactId>
      <version>3.0.2</version>
    </dependency>
    <dependency>
      <groupId>com.google.code.gson</groupId>
      <artifactId>gson</artifactId>
      <version>2.8.5</version>
    </dependency>
    <dependency>
      <groupId>com.google.guava</groupId>
      <artifactId>guava</artifactId>
      <version>27.0.1-jre</version>
    </dependency>
    <dependency>
      <groupId>org.apache.commons</groupId>
      <artifactId>commons-lang3</artifactId>
      <version>3.4</version>
    </dependency>
    <dependency>
      <groupId>org.slf4j</groupId>
      <artifactId>slf4j-api</artifactId>

@@ -22,6 +22,11 @@
      <artifactId>ray-api</artifactId>
      <version>${project.version}</version>
    </dependency>
    <dependency>
      <groupId>io.ray</groupId>
      <artifactId>ray-runtime</artifactId>
      <version>${project.version}</version>
    </dependency>
    <dependency>
      <groupId>org.ray</groupId>
      <artifactId>streaming-state</artifactId>

@@ -0,0 +1,129 @@
package io.ray.streaming.api.context;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.gson.Gson;
import io.ray.api.Ray;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.util.NetworkUtil;
import java.io.File;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class ClusterStarter {
  private static final Logger LOG = LoggerFactory.getLogger(ClusterStarter.class);
  private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/plasma_store_socket";
  private static final String RAYLET_SOCKET_NAME = "/tmp/ray/raylet_socket";

  static synchronized void startCluster(boolean isCrossLanguage, boolean isLocal) {
    Preconditions.checkArgument(Ray.internal() == null);
    RayConfig.reset();
    if (!isLocal) {
      System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
      System.setProperty("ray.run-mode", "CLUSTER");
    } else {
      System.clearProperty("ray.raylet.config.num_workers_per_process_java");
      System.setProperty("ray.run-mode", "SINGLE_PROCESS");
    }

    if (!isCrossLanguage) {
      Ray.init();
      return;
    }

    // Delete existing socket files.
    for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
      File file = new File(socket);
      if (file.exists()) {
        LOG.info("Delete existing socket file {}", file);
        file.delete();
      }
    }

    String nodeManagerPort = String.valueOf(NetworkUtil.getUnusedPort());

    // Jars in the `ray` wheel don't contain test classes, so we add test classes explicitly.
    // Since mvn test classes contain `test` in their path and bazel test classes are located
    // in a jar with `test` in its name, we can check the classpath for `test` to select them.
    String classpath = Stream.of(System.getProperty("java.class.path").split(":"))
        .filter(s -> !s.contains(" ") && s.contains("test"))
        .collect(Collectors.joining(":"));
    String workerOptions = new Gson().toJson(ImmutableList.of("-classpath", classpath));
    Map<String, String> config = new HashMap<>(RayConfig.create().rayletConfigParameters);
    config.put("num_workers_per_process_java", "1");
    // Start ray cluster.
    List<String> startCommand = ImmutableList.of(
        "ray",
        "start",
        "--head",
        "--redis-port=6379",
        String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
        String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
        String.format("--node-manager-port=%s", nodeManagerPort),
        "--load-code-from-local",
        "--include-java",
        "--java-worker-options=" + workerOptions,
        "--internal-config=" + new Gson().toJson(config)
    );
    if (!executeCommand(startCommand, 10)) {
      throw new RuntimeException("Couldn't start ray cluster.");
    }

    // Connect to the cluster.
    System.setProperty("ray.redis.address", "127.0.0.1:6379");
    System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
    System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
    System.setProperty("ray.raylet.node-manager-port", nodeManagerPort);
    Ray.init();
  }

  public static synchronized void stopCluster(boolean isCrossLanguage) {
    // Disconnect from the cluster.
    Ray.shutdown();
    System.clearProperty("ray.redis.address");
    System.clearProperty("ray.object-store.socket-name");
    System.clearProperty("ray.raylet.socket-name");
    System.clearProperty("ray.raylet.node-manager-port");
    System.clearProperty("ray.raylet.config.num_workers_per_process_java");
    System.clearProperty("ray.run-mode");

    if (isCrossLanguage) {
      // Stop ray cluster.
      final List<String> stopCommand = ImmutableList.of(
          "ray",
          "stop"
      );
      if (!executeCommand(stopCommand, 10)) {
        throw new RuntimeException("Couldn't stop ray cluster");
      }
    }
  }

  /**
   * Execute an external command.
   *
   * @return Whether the command succeeded.
   */
  private static boolean executeCommand(List<String> command, int waitTimeoutSeconds) {
    LOG.info("Executing command: {}", String.join(" ", command));
    try {
      ProcessBuilder processBuilder = new ProcessBuilder(command)
          .redirectOutput(ProcessBuilder.Redirect.INHERIT)
          .redirectError(ProcessBuilder.Redirect.INHERIT);
      Process process = processBuilder.start();
      boolean exit = process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS);
      if (!exit) {
        process.destroyForcibly();
      }
      return process.exitValue() == 0;
    } catch (Exception e) {
      throw new RuntimeException("Error executing command " + String.join(" ", command), e);
    }
  }
}

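The classpath filter above is easy to check in isolation; this standalone sketch runs the same Stream pipeline over java.class.path and prints the entries that would be passed to the Java workers:

import java.util.stream.Collectors;
import java.util.stream.Stream;

public class ClasspathFilterDemo {
  public static void main(String[] args) {
    // Same filter as ClusterStarter: keep entries that contain "test" and no spaces.
    String classpath = Stream.of(System.getProperty("java.class.path").split(":"))
        .filter(s -> !s.contains(" ") && s.contains("test"))
        .collect(Collectors.joining(":"));
    System.out.println(classpath);
  }
}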
@@ -1,10 +1,12 @@
package io.ray.streaming.api.context;

import com.google.common.base.Preconditions;
import io.ray.api.Ray;
import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.jobgraph.JobGraphBuilder;
import io.ray.streaming.schedule.JobScheduler;
import io.ray.streaming.util.Config;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;

@@ -13,11 +15,14 @@ import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Encapsulates the context information of a streaming job.
 */
public class StreamingContext implements Serializable {
  private static final Logger LOG = LoggerFactory.getLogger(StreamingContext.class);

  private transient AtomicInteger idGenerator;

@@ -54,6 +59,20 @@ public class StreamingContext implements Serializable {
    this.jobGraph = jobGraphBuilder.build();
    jobGraph.printJobGraph();

    if (Ray.internal() == null) {
      if (Config.MEMORY_CHANNEL.equalsIgnoreCase(jobConfig.get(Config.CHANNEL_TYPE))) {
        Preconditions.checkArgument(!jobGraph.isCrossLanguageGraph());
        ClusterStarter.startCluster(false, true);
        LOG.info("Created local cluster for job {}.", jobName);
      } else {
        ClusterStarter.startCluster(jobGraph.isCrossLanguageGraph(), false);
        LOG.info("Created multi process cluster for job {}.", jobName);
      }
      Runtime.getRuntime().addShutdownHook(new Thread(StreamingContext.this::stop));
    } else {
      LOG.info("Reuse existing cluster.");
    }

    ServiceLoader<JobScheduler> serviceLoader = ServiceLoader.load(JobScheduler.class);
    Iterator<JobScheduler> iterator = serviceLoader.iterator();
    Preconditions.checkArgument(iterator.hasNext(),

@@ -77,4 +96,10 @@ public class StreamingContext implements Serializable {
  public void withConfig(Map<String, String> jobConfig) {
    this.jobConfig = jobConfig;
  }

  public void stop() {
    if (Ray.internal() != null) {
      ClusterStarter.stopCluster(jobGraph.isCrossLanguageGraph());
    }
  }
}

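The bootstrap logic above keys off the channel type: a memory channel forces a local single-process cluster (and is only legal for single-language graphs); anything else starts a multi-process cluster, cross-language when the graph mixes languages. A hedged usage sketch using only methods visible in this diff:

import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.util.Config;
import java.util.HashMap;
import java.util.Map;

public class ContextConfigSketch {
  public static void main(String[] args) {
    Map<String, String> jobConfig = new HashMap<>();
    // MEMORY_CHANNEL -> local single-process cluster; omit it (or use
    // NATIVE_CHANNEL) to take the multi-process / cross-language path.
    jobConfig.put(Config.CHANNEL_TYPE, Config.MEMORY_CHANNEL);
    StreamingContext context = StreamingContext.buildContext();
    context.withConfig(jobConfig);
  }
}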
@@ -1,6 +1,7 @@
package io.ray.streaming.api.stream;


import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.FilterFunction;
import io.ray.streaming.api.function.impl.FlatMapFunction;

@@ -15,24 +16,44 @@ import io.ray.streaming.operator.impl.FlatMapOperator;
import io.ray.streaming.operator.impl.KeyByOperator;
import io.ray.streaming.operator.impl.MapOperator;
import io.ray.streaming.operator.impl.SinkOperator;
import io.ray.streaming.python.stream.PythonDataStream;

/**
 * Represents a stream of data.
 *
 * This class defines all the streaming operations.
 * <p>This class defines all the streaming operations.
 *
 * @param <T> Type of data in the stream.
 */
public class DataStream<T> extends Stream<T> {
public class DataStream<T> extends Stream<DataStream<T>, T> {

  public DataStream(StreamingContext streamingContext, StreamOperator streamOperator) {
    super(streamingContext, streamOperator);
  }

  public DataStream(DataStream input, StreamOperator streamOperator) {
  public DataStream(StreamingContext streamingContext,
                    StreamOperator streamOperator,
                    Partition<T> partition) {
    super(streamingContext, streamOperator, partition);
  }

  public <R> DataStream(DataStream<R> input, StreamOperator streamOperator) {
    super(input, streamOperator);
  }

  public <R> DataStream(DataStream<R> input,
                        StreamOperator streamOperator,
                        Partition<T> partition) {
    super(input, streamOperator, partition);
  }

  /**
   * Create a java stream that references the passed python stream.
   * Changes in the new stream will be reflected in the referenced stream and vice versa.
   */
  public DataStream(PythonDataStream referencedStream) {
    super(referencedStream);
  }

  /**
   * Apply a map function to this stream.
   *

@@ -41,7 +62,7 @@ public class DataStream<T> extends Stream<T> {
   * @return A new DataStream.
   */
  public <R> DataStream<R> map(MapFunction<T, R> mapFunction) {
    return new DataStream<>(this, new MapOperator(mapFunction));
    return new DataStream<>(this, new MapOperator<>(mapFunction));
  }

  /**

@@ -52,11 +73,11 @@ public class DataStream<T> extends Stream<T> {
   * @return A new DataStream
   */
  public <R> DataStream<R> flatMap(FlatMapFunction<T, R> flatMapFunction) {
    return new DataStream(this, new FlatMapOperator(flatMapFunction));
    return new DataStream<>(this, new FlatMapOperator<>(flatMapFunction));
  }

  public DataStream<T> filter(FilterFunction<T> filterFunction) {
    return new DataStream<T>(this, new FilterOperator(filterFunction));
    return new DataStream<>(this, new FilterOperator<>(filterFunction));
  }

  /**

@@ -66,7 +87,7 @@ public class DataStream<T> extends Stream<T> {
   * @return A new UnionStream.
   */
  public UnionStream<T> union(DataStream<T> other) {
    return new UnionStream(this, null, other);
    return new UnionStream<>(this, null, other);
  }

  /**

@@ -93,7 +114,7 @@ public class DataStream<T> extends Stream<T> {
   * @return A new StreamSink.
   */
  public DataStreamSink<T> sink(SinkFunction<T> sinkFunction) {
    return new DataStreamSink<>(this, new SinkOperator(sinkFunction));
    return new DataStreamSink<>(this, new SinkOperator<>(sinkFunction));
  }

  /**

@@ -104,7 +125,8 @@ public class DataStream<T> extends Stream<T> {
   * @return A new KeyDataStream.
   */
  public <K> KeyDataStream<K, T> keyBy(KeyFunction<T, K> keyFunction) {
    return new KeyDataStream<>(this, new KeyByOperator(keyFunction));
    checkPartitionCall();
    return new KeyDataStream<>(this, new KeyByOperator<>(keyFunction));
  }

  /**

@@ -113,8 +135,8 @@ public class DataStream<T> extends Stream<T> {
   * @return This stream.
   */
  public DataStream<T> broadcast() {
    this.partition = new BroadcastPartition<>();
    return this;
    checkPartitionCall();
    return setPartition(new BroadcastPartition<>());
  }

  /**

@@ -124,19 +146,32 @@ public class DataStream<T> extends Stream<T> {
   * @return This stream.
   */
  public DataStream<T> partitionBy(Partition<T> partition) {
    this.partition = partition;
    return this;
    checkPartitionCall();
    return setPartition(partition);
  }

  /**
   * Set parallelism to current transformation.
   *
   * @param parallelism The parallelism to set.
   * @return This stream.
   * If the parent stream is a python stream, we can't call partition-related methods
   * on the java stream.
   */
  public DataStream<T> setParallelism(int parallelism) {
    this.parallelism = parallelism;
    return this;
  private void checkPartitionCall() {
    if (getInputStream() != null && getInputStream().getLanguage() == Language.PYTHON) {
      throw new RuntimeException("Partition related methods can't be called on a " +
          "java stream if parent stream is a python stream.");
    }
  }

  /**
   * Convert this stream as a python stream.
   * The converted stream and this stream are the same logical stream, which has the same stream id.
   * Changes in the converted stream will be reflected in this stream and vice versa.
   */
  public PythonDataStream asPythonStream() {
    return new PythonDataStream(this);
  }

  @Override
  public Language getLanguage() {
    return Language.JAVA;
  }
}

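The new checkPartitionCall guard is easiest to see from the caller's side. A sketch of the failure mode it closes, using the constructors that StreamTest (further down) also uses; the Python module/function names are hypothetical placeholders:

import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.operator.impl.MapOperator;

public class PartitionGuardSketch {
  public static void main(String[] args) {
    DataStream dataStream = new DataStream(StreamingContext.buildContext(),
        new MapOperator(value -> null));
    // The java view's parent is now a python stream, so a java partition
    // object could never be applied to its records.
    DataStream javaView = dataStream.asPythonStream()
        .map("my_module", "my_map_func")
        .asJavaStream();
    try {
      javaView.broadcast(); // rejected by checkPartitionCall()
    } catch (RuntimeException expected) {
      System.out.println(expected.getMessage());
    }
  }
}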
@@ -1,5 +1,6 @@
package io.ray.streaming.api.stream;

import io.ray.streaming.api.Language;
import io.ray.streaming.operator.impl.SinkOperator;

/**

@@ -9,13 +10,13 @@ import io.ray.streaming.operator.impl.SinkOperator;
 */
public class DataStreamSink<T> extends StreamSink<T> {

  public DataStreamSink(DataStream<T> input, SinkOperator sinkOperator) {
  public DataStreamSink(DataStream input, SinkOperator sinkOperator) {
    super(input, sinkOperator);
    this.streamingContext.addSink(this);
    getStreamingContext().addSink(this);
  }

  public DataStreamSink<T> setParallelism(int parallelism) {
    this.parallelism = parallelism;
    return this;
  @Override
  public Language getLanguage() {
    return Language.JAVA;
  }
}

@@ -14,27 +14,26 @@ import java.util.Collection;
 */
public class DataStreamSource<T> extends DataStream<T> implements StreamSource<T> {

  public DataStreamSource(StreamingContext streamingContext, SourceFunction<T> sourceFunction) {
    super(streamingContext, new SourceOperator<>(sourceFunction));
    super.partition = new RoundRobinPartition<>();
  private DataStreamSource(StreamingContext streamingContext, SourceFunction<T> sourceFunction) {
    super(streamingContext, new SourceOperator<>(sourceFunction), new RoundRobinPartition<>());
  }

  public static <T> DataStreamSource<T> fromSource(
      StreamingContext context, SourceFunction<T> sourceFunction) {
    return new DataStreamSource<>(context, sourceFunction);
  }

  /**
   * Build a DataStreamSource source from a collection.
   *
   * @param context Stream context.
   * @param values A collection of values.
   * @param <T> The type of source data.
   * @param values A collection of values.
   * @param <T> The type of source data.
   * @return A DataStreamSource.
   */
  public static <T> DataStreamSource<T> buildSource(
  public static <T> DataStreamSource<T> fromCollection(
      StreamingContext context, Collection<T> values) {
    return new DataStreamSource(context, new CollectionSourceFunction(values));
    return new DataStreamSource<>(context, new CollectionSourceFunction<>(values));
  }

  @Override
  public DataStreamSource<T> setParallelism(int parallelism) {
    this.parallelism = parallelism;
    return this;
  }
}

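For reference, the renamed factory as a caller sees it: buildSource becomes fromCollection, and the constructor goes private so fromSource/fromCollection are the only entry points (this mirrors the updated JobGraphBuilderTest further down):

import com.google.common.collect.Lists;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.stream.DataStreamSource;

public class SourceFactorySketch {
  public static void main(String[] args) {
    StreamingContext context = StreamingContext.buildContext();
    DataStreamSource<String> source =
        DataStreamSource.fromCollection(context, Lists.newArrayList("a", "b", "c"));
    source.setParallelism(2); // fluent setters come from the reworked Stream base class
  }
}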
@@ -2,9 +2,12 @@ package io.ray.streaming.api.stream;

import io.ray.streaming.api.function.impl.AggregateFunction;
import io.ray.streaming.api.function.impl.ReduceFunction;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.partition.impl.KeyPartition;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.operator.impl.ReduceOperator;
import io.ray.streaming.python.stream.PythonDataStream;
import io.ray.streaming.python.stream.PythonKeyDataStream;

/**
 * Represents a DataStream returned by a key-by operation.

@@ -12,11 +15,19 @@ import io.ray.streaming.operator.impl.ReduceOperator;
 * @param <K> Type of the key.
 * @param <T> Type of the data.
 */
@SuppressWarnings("unchecked")
public class KeyDataStream<K, T> extends DataStream<T> {

  public KeyDataStream(DataStream<T> input, StreamOperator streamOperator) {
    super(input, streamOperator);
    this.partition = new KeyPartition();
    super(input, streamOperator, (Partition<T>) new KeyPartition<K, T>());
  }

  /**
   * Create a java stream that references the passed python stream.
   * Changes in the new stream will be reflected in the referenced stream and vice versa.
   */
  public KeyDataStream(PythonDataStream referencedStream) {
    super(referencedStream);
  }

  /**

@@ -41,8 +52,13 @@ public class KeyDataStream<K, T> extends DataStream<T> {
    return new DataStream<>(this, null);
  }

  public KeyDataStream<K, T> setParallelism(int parallelism) {
    this.parallelism = parallelism;
    return this;
  /**
   * Convert this stream as a python stream.
   * The converted stream and this stream are the same logical stream, which has the same stream id.
   * Changes in the converted stream will be reflected in this stream and vice versa.
   */
  public PythonKeyDataStream asPythonStream() {
    return new PythonKeyDataStream(this);
  }

}

@@ -1,58 +1,99 @@
package io.ray.streaming.api.stream;

import com.google.common.base.Preconditions;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
import io.ray.streaming.operator.Operator;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.python.PythonOperator;
import io.ray.streaming.python.PythonPartition;
import io.ray.streaming.python.stream.PythonStream;
import java.io.Serializable;

/**
 * Abstract base class of all stream types.
 *
 * @param <S> Type of the stream class.
 * @param <T> Type of the data in the stream.
 */
public abstract class Stream<T> implements Serializable {
  protected int id;
  protected int parallelism = 1;
  protected StreamOperator operator;
  protected Stream<T> inputStream;
  protected StreamingContext streamingContext;
  protected Partition<T> partition;
public abstract class Stream<S extends Stream<S, T>, T>
    implements Serializable {
  private final int id;
  private final StreamingContext streamingContext;
  private final Stream inputStream;
  private final StreamOperator operator;
  private int parallelism = 1;
  private Partition<T> partition;
  private Stream originalStream;

  @SuppressWarnings("unchecked")
  public Stream(StreamingContext streamingContext, StreamOperator streamOperator) {
    this(streamingContext, null, streamOperator,
        selectPartition(streamOperator));
  }

  public Stream(StreamingContext streamingContext,
                StreamOperator streamOperator,
                Partition<T> partition) {
    this(streamingContext, null, streamOperator, partition);
  }

  public Stream(Stream inputStream, StreamOperator streamOperator) {
    this(inputStream.getStreamingContext(), inputStream, streamOperator,
        selectPartition(streamOperator));
  }

  public Stream(Stream inputStream, StreamOperator streamOperator, Partition<T> partition) {
    this(inputStream.getStreamingContext(), inputStream, streamOperator, partition);
  }

  protected Stream(StreamingContext streamingContext,
                   Stream inputStream,
                   StreamOperator streamOperator,
                   Partition<T> partition) {
    this.streamingContext = streamingContext;
    this.inputStream = inputStream;
    this.operator = streamOperator;
    this.partition = partition;
    this.id = streamingContext.generateId();
    if (streamOperator instanceof PythonOperator) {
      this.partition = PythonPartition.RoundRobinPartition;
    } else {
      this.partition = new RoundRobinPartition<>();
    if (inputStream != null) {
      this.parallelism = inputStream.getParallelism();
    }
  }

  public Stream(Stream<T> inputStream, StreamOperator streamOperator) {
    this.inputStream = inputStream;
    this.parallelism = inputStream.getParallelism();
    this.streamingContext = this.inputStream.getStreamingContext();
    this.operator = streamOperator;
    this.id = streamingContext.generateId();
    this.partition = selectPartition();
  /**
   * Create a proxy stream of the original stream.
   * Changes in the new stream will be reflected in the original stream and vice versa.
   */
  protected Stream(Stream originalStream) {
    this.originalStream = originalStream;
    this.id = originalStream.getId();
    this.streamingContext = originalStream.getStreamingContext();
    this.inputStream = originalStream.getInputStream();
    this.operator = originalStream.getOperator();
  }

  @SuppressWarnings("unchecked")
  private Partition<T> selectPartition() {
    if (inputStream instanceof PythonStream) {
      return PythonPartition.RoundRobinPartition;
    } else {
      return new RoundRobinPartition<>();
  private static <T> Partition<T> selectPartition(Operator operator) {
    switch (operator.getLanguage()) {
      case PYTHON:
        return (Partition<T>) PythonPartition.RoundRobinPartition;
      case JAVA:
        return new RoundRobinPartition<>();
      default:
        throw new UnsupportedOperationException(
            "Unsupported language " + operator.getLanguage());
    }
  }

  public Stream<T> getInputStream() {
  public int getId() {
    return id;
  }

  public StreamingContext getStreamingContext() {
    return streamingContext;
  }

  public Stream getInputStream() {
    return inputStream;
  }

@@ -60,32 +101,47 @@ public abstract class Stream<T> implements Serializable {
    return operator;
  }

  public void setOperator(StreamOperator operator) {
    this.operator = operator;
  }

  public StreamingContext getStreamingContext() {
    return streamingContext;
  @SuppressWarnings("unchecked")
  private S self() {
    return (S) this;
  }

  public int getParallelism() {
    return parallelism;
    return originalStream != null ? originalStream.getParallelism() : parallelism;
  }

  public Stream<T> setParallelism(int parallelism) {
    this.parallelism = parallelism;
    return this;
  }

  public int getId() {
    return id;
  public S setParallelism(int parallelism) {
    if (originalStream != null) {
      originalStream.setParallelism(parallelism);
    } else {
      this.parallelism = parallelism;
    }
    return self();
  }

  @SuppressWarnings("unchecked")
  public Partition<T> getPartition() {
    return partition;
    return originalStream != null ? originalStream.getPartition() : partition;
  }

  public void setPartition(Partition<T> partition) {
    this.partition = partition;
  @SuppressWarnings("unchecked")
  protected S setPartition(Partition<T> partition) {
    if (originalStream != null) {
      originalStream.setPartition(partition);
    } else {
      this.partition = partition;
    }
    return self();
  }

  public boolean isProxyStream() {
    return originalStream != null;
  }

  public Stream getOriginalStream() {
    Preconditions.checkArgument(isProxyStream());
    return originalStream;
  }

  public abstract Language getLanguage();
}

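The signature change to Stream<S extends Stream<S, T>, T> is the recurring-generic ("self type") idiom: a fluent setter defined once on the base class can return the concrete subclass, which is why the per-subclass setParallelism overrides are deleted throughout this diff. A standalone illustration of the pattern — the class names here are illustrative, not from the diff:

abstract class Node<S extends Node<S>> {
  private int parallelism = 1;

  @SuppressWarnings("unchecked")
  private S self() {
    return (S) this; // safe as long as each subclass binds S to itself
  }

  // Defined once, yet returns the concrete subtype for every subclass.
  public S setParallelism(int parallelism) {
    this.parallelism = parallelism;
    return self();
  }
}

class SourceNode extends Node<SourceNode> {
  SourceNode shuffle() {
    return this;
  }
}

class SelfTypeDemo {
  public static void main(String[] args) {
    // The chain keeps compiling because setParallelism returns SourceNode:
    new SourceNode().setParallelism(4).shuffle();
  }
}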
@@ -7,8 +7,8 @@ import io.ray.streaming.operator.StreamOperator;
 *
 * @param <T> Type of the input data of this sink.
 */
public class StreamSink<T> extends Stream<T> {
  public StreamSink(Stream<T> inputStream, StreamOperator streamOperator) {
public abstract class StreamSink<T> extends Stream<StreamSink<T>, T> {
  public StreamSink(Stream inputStream, StreamOperator streamOperator) {
    super(inputStream, streamOperator);
  }
}

@@ -11,15 +11,15 @@ import java.util.List;
 */
public class UnionStream<T> extends DataStream<T> {

  private List<DataStream> unionStreams;
  private List<DataStream<T>> unionStreams;

  public UnionStream(DataStream input, StreamOperator streamOperator, DataStream<T> other) {
  public UnionStream(DataStream<T> input, StreamOperator streamOperator, DataStream<T> other) {
    super(input, streamOperator);
    this.unionStreams = new ArrayList<>();
    this.unionStreams.add(other);
  }

  public List<DataStream> getUnionStreams() {
  public List<DataStream<T>> getUnionStreams() {
    return unionStreams;
  }
}

@@ -1,5 +1,6 @@
package io.ray.streaming.jobgraph;

import io.ray.streaming.api.Language;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

@@ -97,4 +98,14 @@ public class JobGraph implements Serializable {
    }
  }

  public boolean isCrossLanguageGraph() {
    Language language = jobVertexList.get(0).getLanguage();
    for (JobVertex jobVertex : jobVertexList) {
      if (jobVertex.getLanguage() != language) {
        return true;
      }
    }
    return false;
  }

}

@@ -1,5 +1,6 @@
package io.ray.streaming.jobgraph;

import com.google.common.base.Preconditions;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.Stream;
import io.ray.streaming.api.stream.StreamSink;

@@ -10,8 +11,11 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JobGraphBuilder {
  private static final Logger LOG = LoggerFactory.getLogger(JobGraphBuilder.class);

  private JobGraph jobGraph;

@@ -41,12 +45,19 @@ public class JobGraphBuilder {
  }

  private void processStream(Stream stream) {
    while (stream.isProxyStream()) {
      // A proxy stream and its original stream are the same logical stream; both refer to
      // the same data flow transformation. We skip proxy streams to avoid applying the
      // same transformation multiple times.
      LOG.debug("Skip proxy stream {} of id {}", stream, stream.getId());
      stream = stream.getOriginalStream();
    }
    StreamOperator streamOperator = stream.getOperator();
    Preconditions.checkArgument(stream.getLanguage() == streamOperator.getLanguage(),
        "Reference stream should be skipped.");
    int vertexId = stream.getId();
    int parallelism = stream.getParallelism();

    StreamOperator streamOperator = stream.getOperator();
    JobVertex jobVertex = null;

    JobVertex jobVertex;
    if (stream instanceof StreamSink) {
      jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator);
      Stream parentStream = stream.getInputStream();

@@ -1,6 +1,8 @@
package io.ray.streaming.message;


import java.util.Objects;

public class KeyRecord<K, T> extends Record<T> {

  private K key;

@@ -17,4 +19,24 @@ public class KeyRecord<K, T> extends Record<T> {
  public void setKey(K key) {
    this.key = key;
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    if (!super.equals(o)) {
      return false;
    }
    KeyRecord<?, ?> keyRecord = (KeyRecord<?, ?>) o;
    return Objects.equals(key, keyRecord.key);
  }

  @Override
  public int hashCode() {
    return Objects.hash(super.hashCode(), key);
  }
}

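With equals and hashCode defined, records compare by content rather than identity, so they behave correctly in assertions and hash-based collections. A hedged sketch — it assumes KeyRecord keeps a (key, value) constructor, which is not shown in this hunk:

KeyRecord<String, Integer> a = new KeyRecord<>("k", 1); // hypothetical ctor shape
KeyRecord<String, Integer> b = new KeyRecord<>("k", 1);
assert a.equals(b) && a.hashCode() == b.hashCode(); // content equality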
@@ -1,64 +0,0 @@
package io.ray.streaming.message;

import com.google.common.collect.Lists;
import java.io.Serializable;
import java.util.List;

public class Message implements Serializable {

  private int taskId;
  private long batchId;
  private String stream;
  private List<Record> recordList;

  public Message(int taskId, long batchId, String stream, List<Record> recordList) {
    this.taskId = taskId;
    this.batchId = batchId;
    this.stream = stream;
    this.recordList = recordList;
  }

  public Message(int taskId, long batchId, String stream, Record record) {
    this.taskId = taskId;
    this.batchId = batchId;
    this.stream = stream;
    this.recordList = Lists.newArrayList(record);
  }

  public int getTaskId() {
    return taskId;
  }

  public void setTaskId(int taskId) {
    this.taskId = taskId;
  }

  public long getBatchId() {
    return batchId;
  }

  public void setBatchId(long batchId) {
    this.batchId = batchId;
  }

  public String getStream() {
    return stream;
  }

  public void setStream(String stream) {
    this.stream = stream;
  }

  public List<Record> getRecordList() {
    return recordList;
  }

  public void setRecordList(List<Record> recordList) {
    this.recordList = recordList;
  }

  public Record getRecord(int index) {
    return recordList.get(0);
  }

}

@@ -1,6 +1,7 @@
package io.ray.streaming.message;

import java.io.Serializable;
import java.util.Objects;


public class Record<T> implements Serializable {

@@ -27,6 +28,24 @@ public class Record<T> implements Serializable {
    this.stream = stream;
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    Record<?> record = (Record<?>) o;
    return Objects.equals(stream, record.stream) &&
        Objects.equals(value, record.value);
  }

  @Override
  public int hashCode() {
    return Objects.hash(stream, value);
  }

  @Override
  public String toString() {
    return value.toString();

@@ -1,6 +1,8 @@
package io.ray.streaming.python;

import com.google.common.base.Preconditions;
import io.ray.streaming.api.function.Function;
import org.apache.commons.lang3.StringUtils;

/**
 * Represents a user defined python function.

@@ -14,9 +16,8 @@ import io.ray.streaming.api.function.Function;
 *
 * <p>If the python data stream api is invoked from python, `function` will be not null.</p>
 * <p>If the python data stream api is invoked from java, `moduleName` and
 * `className`/`functionName` will be not null.</p>
 * `functionName` will be not null.</p>
 * <p>
 * TODO serialize to bytes using protobuf
 */
public class PythonFunction implements Function {
  public enum FunctionInterface {

@@ -38,23 +39,43 @@ public class PythonFunction implements Function {
  }
  }

  private byte[] function;
  private String moduleName;
  private String className;
  private String functionName;
  // null if this function is constructed from moduleName/functionName.
  private final byte[] function;
  // null if this function is constructed from a serialized python function.
  private final String moduleName;
  // null if this function is constructed from a serialized python function.
  private final String functionName;
  /**
   * FunctionInterface can be used to validate the python function,
   * and to look up the operator class from FunctionInterface.
   */
  private String functionInterface;

  private PythonFunction(byte[] function,
                         String moduleName,
                         String className,
                         String functionName) {
  /**
   * Create a {@link PythonFunction} from a serialized streaming python function.
   *
   * @param function serialized streaming python function from the python driver.
   */
  public PythonFunction(byte[] function) {
    Preconditions.checkNotNull(function);
    this.function = function;
    this.moduleName = null;
    this.functionName = null;
  }

  /**
   * Create a {@link PythonFunction} from a module name and streaming function name.
   *
   * @param moduleName module name of the streaming function.
   * @param functionName function name of the streaming function. {@code functionName} is the
   *     name of a python function, or the class name of a subclass of `ray.streaming.function.`
   */
  public PythonFunction(String moduleName,
                        String functionName) {
    Preconditions.checkArgument(StringUtils.isNotBlank(moduleName));
    Preconditions.checkArgument(StringUtils.isNotBlank(functionName));
    this.function = null;
    this.moduleName = moduleName;
    this.className = className;
    this.functionName = functionName;
  }

@@ -70,10 +91,6 @@ public class PythonFunction implements Function {
    return moduleName;
  }

  public String getClassName() {
    return className;
  }

  public String getFunctionName() {
    return functionName;
  }

@@ -82,34 +99,4 @@ public class PythonFunction implements Function {
    return functionInterface;
  }

  /**
   * Create a {@link PythonFunction} using a python serialized function.
   *
   * @param function serialized python function sent from the python driver
   */
  public static PythonFunction fromFunction(byte[] function) {
    return new PythonFunction(function, null, null, null);
  }

  /**
   * Create a {@link PythonFunction} using <code>moduleName</code> and
   * <code>className</code>.
   *
   * @param moduleName python module name
   * @param className python class name
   */
  public static PythonFunction fromClassName(String moduleName, String className) {
    return new PythonFunction(null, moduleName, className, null);
  }

  /**
   * Create a {@link PythonFunction} using <code>moduleName</code> and
   * <code>functionName</code>.
   *
   * @param moduleName python module name
   * @param functionName python function name
   */
  public static PythonFunction fromFunctionName(String moduleName, String functionName) {
    return new PythonFunction(null, moduleName, null, functionName);
  }
}

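After this change a PythonFunction is built in exactly two ways; the static factories (fromFunction, fromClassName, fromFunctionName) and the className field are gone. A short sketch — the module and function names are placeholders, and the byte array stands in for a pickled function shipped from a python driver:

// By module + function/class name, loaded on the python worker via importlib:
PythonFunction byName = new PythonFunction("my_module", "my_map_func");
// Or directly from a serialized python function:
byte[] pickled = new byte[]{0}; // placeholder bytes, not a real pickle
PythonFunction byBytes = new PythonFunction(pickled);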
@@ -1,6 +1,8 @@
package io.ray.streaming.python;

import com.google.common.base.Preconditions;
import io.ray.streaming.api.partition.Partition;
import org.apache.commons.lang3.StringUtils;

/**
 * Represents a python partition function.

@@ -13,28 +15,33 @@ import io.ray.streaming.api.partition.Partition;
 * If this object is constructed from moduleName and className/functionName,
 * the python worker will use `importlib` to load the python partition function.
 * <p>
 * TODO serialize to bytes using protobuf
 */
public class PythonPartition implements Partition {
public class PythonPartition implements Partition<Object> {
  public static final PythonPartition BroadcastPartition = new PythonPartition(
      "ray.streaming.partition", "BroadcastPartition", null);
      "ray.streaming.partition", "BroadcastPartition");
  public static final PythonPartition KeyPartition = new PythonPartition(
      "ray.streaming.partition", "KeyPartition", null);
      "ray.streaming.partition", "KeyPartition");
  public static final PythonPartition RoundRobinPartition = new PythonPartition(
      "ray.streaming.partition", "RoundRobinPartition", null);
      "ray.streaming.partition", "RoundRobinPartition");

  private byte[] partition;
  private String moduleName;
  private String className;
  private String functionName;

  public PythonPartition(byte[] partition) {
    Preconditions.checkNotNull(partition);
    this.partition = partition;
  }

  public PythonPartition(String moduleName, String className, String functionName) {
  /**
   * Create a python partition from a module name and partition function name.
   *
   * @param moduleName module name of the python partition
   * @param functionName function/class name of the partition function.
   */
  public PythonPartition(String moduleName, String functionName) {
    Preconditions.checkArgument(StringUtils.isNotBlank(moduleName));
    Preconditions.checkArgument(StringUtils.isNotBlank(functionName));
    this.moduleName = moduleName;
    this.className = className;
    this.functionName = functionName;
  }

@@ -53,10 +60,6 @@ public class PythonPartition implements Partition {
    return moduleName;
  }

  public String getClassName() {
    return className;
  }

  public String getFunctionName() {
    return functionName;
  }

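The same consolidation applies to partitions: the three built-ins now carry just (moduleName, functionName), and custom python partitions use the same two-argument constructor. A sketch with placeholder names:

PythonPartition roundRobin = PythonPartition.RoundRobinPartition; // built-in
PythonPartition custom =
    new PythonPartition("my_module", "MyPartition"); // loaded on the worker via importlib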
@@ -1,6 +1,9 @@
package io.ray.streaming.python.stream;

import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.Stream;
import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonFunction.FunctionInterface;

@@ -10,19 +13,39 @@ import io.ray.streaming.python.PythonPartition;
/**
 * Represents a stream of data whose transformations will be executed in python.
 */
public class PythonDataStream extends Stream implements PythonStream {
public class PythonDataStream extends Stream<PythonDataStream, Object> implements PythonStream {

  protected PythonDataStream(StreamingContext streamingContext,
                             PythonOperator pythonOperator) {
    super(streamingContext, pythonOperator);
  }

  protected PythonDataStream(StreamingContext streamingContext,
                             PythonOperator pythonOperator,
                             Partition<Object> partition) {
    super(streamingContext, pythonOperator, partition);
  }

  public PythonDataStream(PythonDataStream input, PythonOperator pythonOperator) {
    super(input, pythonOperator);
  }

  protected PythonDataStream(Stream inputStream, PythonOperator pythonOperator) {
    super(inputStream, pythonOperator);
  public PythonDataStream(PythonDataStream input,
                          PythonOperator pythonOperator,
                          Partition<Object> partition) {
    super(input, pythonOperator, partition);
  }

  /**
   * Create a python stream that references the passed java stream.
   * Changes in the new stream will be reflected in the referenced stream and vice versa.
   */
  public PythonDataStream(DataStream referencedStream) {
    super(referencedStream);
  }

  public PythonDataStream map(String moduleName, String funcName) {
    return map(new PythonFunction(moduleName, funcName));
  }

  /**

@@ -36,6 +59,10 @@ public class PythonDataStream extends Stream implements PythonStream {
    return new PythonDataStream(this, new PythonOperator(func));
  }

  public PythonDataStream flatMap(String moduleName, String funcName) {
    return flatMap(new PythonFunction(moduleName, funcName));
  }

  /**
   * Apply a flat-map function to this stream.
   *

@@ -47,6 +74,10 @@ public class PythonDataStream extends Stream implements PythonStream {
    return new PythonDataStream(this, new PythonOperator(func));
  }

  public PythonDataStream filter(String moduleName, String funcName) {
    return filter(new PythonFunction(moduleName, funcName));
  }

  /**
   * Apply a filter function to this stream.
   *

@@ -59,6 +90,10 @@ public class PythonDataStream extends Stream implements PythonStream {
    return new PythonDataStream(this, new PythonOperator(func));
  }

  public PythonStreamSink sink(String moduleName, String funcName) {
    return sink(new PythonFunction(moduleName, funcName));
  }

  /**
   * Apply a sink function and get a StreamSink.
   *

@@ -70,6 +105,10 @@ public class PythonDataStream extends Stream implements PythonStream {
    return new PythonStreamSink(this, new PythonOperator(func));
  }

  public PythonKeyDataStream keyBy(String moduleName, String funcName) {
    return keyBy(new PythonFunction(moduleName, funcName));
  }

  /**
   * Apply a key-by function to this stream.
   *

@@ -77,6 +116,7 @@ public class PythonDataStream extends Stream implements PythonStream {
   * @return A new KeyDataStream.
   */
  public PythonKeyDataStream keyBy(PythonFunction func) {
    checkPartitionCall();
    func.setFunctionInterface(FunctionInterface.KEY_FUNCTION);
    return new PythonKeyDataStream(this, new PythonOperator(func));
  }

@@ -87,8 +127,8 @@ public class PythonDataStream extends Stream implements PythonStream {
   * @return This stream.
   */
  public PythonDataStream broadcast() {
    this.partition = PythonPartition.BroadcastPartition;
    return this;
    checkPartitionCall();
    return setPartition(PythonPartition.BroadcastPartition);
  }

  /**

@@ -98,19 +138,33 @@ public class PythonDataStream extends Stream implements PythonStream {
   * @return This stream.
   */
  public PythonDataStream partitionBy(PythonPartition partition) {
    this.partition = partition;
    return this;
    checkPartitionCall();
    return setPartition(partition);
  }

  /**
   * Set parallelism to current transformation.
   *
   * @param parallelism The parallelism to set.
   * @return This stream.
   * If the parent stream is a java stream, we can't call partition-related methods
   * on the python stream.
   */
  public PythonDataStream setParallelism(int parallelism) {
    this.parallelism = parallelism;
    return this;
  private void checkPartitionCall() {
    if (getInputStream() != null && getInputStream().getLanguage() == Language.JAVA) {
      throw new RuntimeException("Partition related methods can't be called on a " +
          "python stream if parent stream is a java stream.");
    }
  }

  /**
   * Convert this stream as a java stream.
   * The converted stream and this stream are the same logical stream, which has the same stream id.
   * Changes in the converted stream will be reflected in this stream and vice versa.
   */
  public DataStream<Object> asJavaStream() {
    return new DataStream<>(this);
  }

  @Override
  public Language getLanguage() {
    return Language.PYTHON;
  }

}

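The new string-based overloads let a java driver wire python operators without touching PythonFunction directly. A hedged chain over the methods added above — javaStream is assumed to already exist, and the module/function names are placeholders:

PythonStreamSink sink = javaStream.asPythonStream()
    .filter("my_module", "my_filter")   // wraps new PythonFunction(module, name)
    .keyBy("my_module", "my_key_func")  // returns a PythonKeyDataStream
    .reduce("my_module", "my_reduce")
    .sink("my_module", "my_sink");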
@@ -1,5 +1,7 @@
package io.ray.streaming.python.stream;

import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.KeyDataStream;
import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonFunction.FunctionInterface;
import io.ray.streaming.python.PythonOperator;

@@ -8,11 +10,23 @@ import io.ray.streaming.python.PythonPartition;
/**
 * Represents a python DataStream returned by a key-by operation.
 */
public class PythonKeyDataStream extends PythonDataStream implements PythonStream {
@SuppressWarnings("unchecked")
public class PythonKeyDataStream extends PythonDataStream implements PythonStream {

  public PythonKeyDataStream(PythonDataStream input, PythonOperator pythonOperator) {
    super(input, pythonOperator);
    this.partition = PythonPartition.KeyPartition;
    super(input, pythonOperator, PythonPartition.KeyPartition);
  }

  /**
   * Create a python stream that references the passed python stream.
   * Changes in the new stream will be reflected in the referenced stream and vice versa.
   */
  public PythonKeyDataStream(DataStream referencedStream) {
    super(referencedStream);
  }

  public PythonDataStream reduce(String moduleName, String funcName) {
    return reduce(new PythonFunction(moduleName, funcName));
  }

  /**

@@ -26,9 +40,13 @@ public class PythonKeyDataStream extends PythonDataStream implements PythonStream {
    return new PythonDataStream(this, new PythonOperator(func));
  }

  public PythonKeyDataStream setParallelism(int parallelism) {
    this.parallelism = parallelism;
    return this;
  /**
   * Convert this stream as a java stream.
   * The converted stream and this stream are the same logical stream, which has the same stream id.
   * Changes in the converted stream will be reflected in this stream and vice versa.
   */
  public KeyDataStream<Object, Object> asJavaStream() {
    return new KeyDataStream(this);
  }

}

@@ -1,5 +1,6 @@
package io.ray.streaming.python.stream;

import io.ray.streaming.api.Language;
import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.python.PythonOperator;


@@ -9,12 +10,12 @@ import io.ray.streaming.python.PythonOperator;
public class PythonStreamSink extends StreamSink implements PythonStream {
  public PythonStreamSink(PythonDataStream input, PythonOperator sinkOperator) {
    super(input, sinkOperator);
    this.streamingContext.addSink(this);
    getStreamingContext().addSink(this);
  }

  public PythonStreamSink setParallelism(int parallelism) {
    this.parallelism = parallelism;
    return this;
  @Override
  public Language getLanguage() {
    return Language.PYTHON;
  }

}

@@ -13,17 +13,12 @@ import io.ray.streaming.python.PythonPartition;
public class PythonStreamSource extends PythonDataStream implements StreamSource {

  private PythonStreamSource(StreamingContext streamingContext, PythonFunction sourceFunction) {
    super(streamingContext, new PythonOperator(sourceFunction));
    super.partition = PythonPartition.RoundRobinPartition;
  }

  public PythonStreamSource setParallelism(int parallelism) {
    this.parallelism = parallelism;
    return this;
    super(streamingContext, new PythonOperator(sourceFunction),
        PythonPartition.RoundRobinPartition);
  }

  public static PythonStreamSource from(StreamingContext streamingContext,
                                        PythonFunction sourceFunction) {
      PythonFunction sourceFunction) {
    sourceFunction.setFunctionInterface(FunctionInterface.SOURCE_FUNCTION);
    return new PythonStreamSource(streamingContext, sourceFunction);
  }

@@ -21,7 +21,6 @@ public class Config {
  public static final String CHANNEL_TYPE = "channel_type";
  public static final String MEMORY_CHANNEL = "memory_channel";
  public static final String NATIVE_CHANNEL = "native_channel";
  public static final String DEFAULT_CHANNEL_TYPE = NATIVE_CHANNEL;
  public static final String CHANNEL_SIZE = "channel_size";
  public static final String CHANNEL_SIZE_DEFAULT = String.valueOf((long) Math.pow(10, 8));
  public static final String IS_RECREATE = "streaming.is_recreate";

@@ -0,0 +1,40 @@
package io.ray.streaming.api.stream;

import static org.testng.Assert.assertEquals;

import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.operator.impl.MapOperator;
import io.ray.streaming.python.stream.PythonDataStream;
import io.ray.streaming.python.stream.PythonKeyDataStream;
import org.testng.annotations.Test;

@SuppressWarnings("unchecked")
public class StreamTest {

  @Test
  public void testReferencedDataStream() {
    DataStream dataStream = new DataStream(StreamingContext.buildContext(),
        new MapOperator(value -> null));
    PythonDataStream pythonDataStream = dataStream.asPythonStream();
    DataStream javaStream = pythonDataStream.asJavaStream();
    assertEquals(dataStream.getId(), pythonDataStream.getId());
    assertEquals(dataStream.getId(), javaStream.getId());
    javaStream.setParallelism(10);
    assertEquals(dataStream.getParallelism(), pythonDataStream.getParallelism());
    assertEquals(dataStream.getParallelism(), javaStream.getParallelism());
  }

  @Test
  public void testReferencedKeyDataStream() {
    DataStream dataStream = new DataStream(StreamingContext.buildContext(),
        new MapOperator(value -> null));
    KeyDataStream keyDataStream = dataStream.keyBy(value -> null);
    PythonKeyDataStream pythonKeyDataStream = keyDataStream.asPythonStream();
    KeyDataStream javaKeyDataStream = pythonKeyDataStream.asJavaStream();
    assertEquals(keyDataStream.getId(), pythonKeyDataStream.getId());
    assertEquals(keyDataStream.getId(), javaKeyDataStream.getId());
    javaKeyDataStream.setParallelism(10);
    assertEquals(keyDataStream.getParallelism(), pythonKeyDataStream.getParallelism());
    assertEquals(keyDataStream.getParallelism(), javaKeyDataStream.getParallelism());
  }
}
@@ -38,7 +38,7 @@ public class JobGraphBuilderTest {

  public JobGraph buildDataSyncJobGraph() {
    StreamingContext streamingContext = StreamingContext.buildContext();
    DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext,
    DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
        Lists.newArrayList("a", "b", "c"));
    StreamSink streamSink = dataStream.sink(x -> LOG.info(x));
    JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink));

@@ -73,7 +73,7 @@ public class JobGraphBuilderTest {

  public JobGraph buildKeyByJobGraph() {
    StreamingContext streamingContext = StreamingContext.buildContext();
    DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext,
    DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
        Lists.newArrayList("1", "2", "3", "4"));
    StreamSink streamSink = dataStream.keyBy(x -> x)
        .sink(x -> LOG.info(x));
@@ -36,6 +36,11 @@
    <artifactId>flatbuffers-java</artifactId>
    <version>1.9.0.1</version>
  </dependency>
  <dependency>
    <groupId>com.google.code.findbugs</groupId>
    <artifactId>jsr305</artifactId>
    <version>3.0.2</version>
  </dependency>
  <dependency>
    <groupId>com.google.guava</groupId>
    <artifactId>guava</artifactId>

@@ -56,6 +61,11 @@
    <artifactId>owner</artifactId>
    <version>1.0.10</version>
  </dependency>
  <dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-lang3</artifactId>
    <version>3.4</version>
  </dependency>
  <dependency>
    <groupId>org.mockito</groupId>
    <artifactId>mockito-all</artifactId>

@@ -71,11 +81,6 @@
    <artifactId>powermock-api-mockito</artifactId>
    <version>1.6.6</version>
  </dependency>
  <dependency>
    <groupId>org.powermock</groupId>
    <artifactId>powermock-core</artifactId>
    <version>1.6.6</version>
  </dependency>
  <dependency>
    <groupId>org.powermock</groupId>
    <artifactId>powermock-module-testng</artifactId>
@@ -1,9 +1,14 @@
package io.ray.streaming.runtime.core.collector;

import io.ray.runtime.serializer.Serializer;
import io.ray.api.BaseActor;
import io.ray.api.RayPyActor;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.collector.Collector;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.message.Record;
import io.ray.streaming.runtime.serialization.CrossLangSerializer;
import io.ray.streaming.runtime.serialization.JavaSerializer;
import io.ray.streaming.runtime.serialization.Serializer;
import io.ray.streaming.runtime.transfer.ChannelID;
import io.ray.streaming.runtime.transfer.DataWriter;
import java.nio.ByteBuffer;

@@ -14,15 +19,24 @@ import org.slf4j.LoggerFactory;
public class OutputCollector implements Collector<Record> {
  private static final Logger LOGGER = LoggerFactory.getLogger(OutputCollector.class);

  private Partition partition;
  private DataWriter writer;
  private ChannelID[] outputQueues;
  private final DataWriter writer;
  private final ChannelID[] outputQueues;
  private final Collection<BaseActor> targetActors;
  private final Language[] targetLanguages;
  private final Partition partition;
  private final Serializer javaSerializer = new JavaSerializer();
  private final Serializer crossLangSerializer = new CrossLangSerializer();

  public OutputCollector(Collection<String> outputQueueIds,
      DataWriter writer,
  public OutputCollector(DataWriter writer,
      Collection<String> outputQueueIds,
      Collection<BaseActor> targetActors,
      Partition partition) {
    this.outputQueues = outputQueueIds.stream().map(ChannelID::from).toArray(ChannelID[]::new);
    this.writer = writer;
    this.outputQueues = outputQueueIds.stream().map(ChannelID::from).toArray(ChannelID[]::new);
    this.targetActors = targetActors;
    this.targetLanguages = targetActors.stream()
        .map(actor -> actor instanceof RayPyActor ? Language.PYTHON : Language.JAVA)
        .toArray(Language[]::new);
    this.partition = partition;
    LOGGER.debug("OutputCollector constructed, outputQueueIds:{}, partition:{}.",
        outputQueueIds, this.partition);

@@ -31,9 +45,32 @@ public class OutputCollector implements Collector<Record> {
  @Override
  public void collect(Record record) {
    int[] partitions = this.partition.partition(record, outputQueues.length);
    ByteBuffer msgBuffer = ByteBuffer.wrap(Serializer.encode(record).getLeft());
    ByteBuffer javaBuffer = null;
    ByteBuffer crossLangBuffer = null;
    for (int partition : partitions) {
      writer.write(outputQueues[partition], msgBuffer);
      if (targetLanguages[partition] == Language.JAVA) {
        // avoid repeated serialization
        if (javaBuffer == null) {
          byte[] bytes = javaSerializer.serialize(record);
          javaBuffer = ByteBuffer.allocate(1 + bytes.length);
          javaBuffer.put(Serializer.JAVA_TYPE_ID);
          // TODO(chaokunyang) remove copy
          javaBuffer.put(bytes);
          javaBuffer.flip();
        }
        writer.write(outputQueues[partition], javaBuffer.duplicate());
      } else {
        // avoid repeated serialization
        if (crossLangBuffer == null) {
          byte[] bytes = crossLangSerializer.serialize(record);
          crossLangBuffer = ByteBuffer.allocate(1 + bytes.length);
          crossLangBuffer.put(Serializer.CROSS_LANG_TYPE_ID);
          // TODO(chaokunyang) remove copy
          crossLangBuffer.put(bytes);
          crossLangBuffer.flip();
        }
        writer.write(outputQueues[partition], crossLangBuffer.duplicate());
      }
    }
  }
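Editor's note: the collect() path above frames every message as one type-id byte followed by the serialized record, building each buffer at most once per call and duplicating it per channel. A minimal sketch of that framing, with the type-id values mirroring the Serializer constants introduced later in this change:

# Hedged sketch of the frame layout; the values mirror
# Serializer.CROSS_LANG_TYPE_ID (0) and Serializer.JAVA_TYPE_ID (1).
CROSS_LANG_TYPE_ID = b"\x00"  # payload readable by both java and python
JAVA_TYPE_ID = b"\x01"        # payload readable only by the JVM (Fst)

def frame(type_id: bytes, payload: bytes) -> bytes:
    # one leading byte tells the reader which deserializer to use
    return type_id + payload

def peek_format(buf: bytes) -> str:
    return "java" if buf[:1] == JAVA_TYPE_ID else "cross-lang"

assert peek_format(frame(JAVA_TYPE_ID, b"payload")) == "java"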
@@ -12,6 +12,7 @@ import io.ray.streaming.runtime.core.graph.ExecutionNode;
import io.ray.streaming.runtime.core.graph.ExecutionTask;
import io.ray.streaming.runtime.generated.RemoteCall;
import io.ray.streaming.runtime.generated.Streaming;
import io.ray.streaming.runtime.serialization.MsgPackSerializer;
import java.util.Arrays;

public class GraphPbBuilder {

@@ -74,11 +75,10 @@ public class GraphPbBuilder {
  private byte[] serializeFunction(Function function) {
    if (function instanceof PythonFunction) {
      PythonFunction pyFunc = (PythonFunction) function;
      // function_bytes, module_name, class_name, function_name, function_interface
      // function_bytes, module_name, function_name, function_interface
      return serializer.serialize(Arrays.asList(
          pyFunc.getFunction(), pyFunc.getModuleName(),
          pyFunc.getClassName(), pyFunc.getFunctionName(),
          pyFunc.getFunctionInterface()
          pyFunc.getFunctionName(), pyFunc.getFunctionInterface()
      ));
    } else {
      return new byte[0];

@@ -88,10 +88,10 @@ public class GraphPbBuilder {
  private byte[] serializePartition(Partition partition) {
    if (partition instanceof PythonPartition) {
      PythonPartition pythonPartition = (PythonPartition) partition;
      // partition_bytes, module_name, class_name, function_name
      // partition_bytes, module_name, function_name
      return serializer.serialize(Arrays.asList(
          pythonPartition.getPartition(), pythonPartition.getModuleName(),
          pythonPartition.getClassName(), pythonPartition.getFunctionName()
          pythonPartition.getFunctionName()
      ));
    } else {
      return new byte[0];
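Editor's note: the comments above document the msgpack tuple layout after class_name was dropped. A hedged sketch of what the python peer would decode; the concrete field values below are invented placeholders, not taken from a real job graph:

import msgpack

# [function_bytes, module_name, function_name, function_interface]
encoded = msgpack.packb(
    [None, "my_module", "my_func", "SOURCE_FUNCTION"], use_bin_type=True)
function_bytes, module_name, function_name, function_interface = \
    msgpack.unpackb(encoded, raw=False)
assert function_name == "my_func"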
@@ -1,16 +1,21 @@
package io.ray.streaming.runtime.python;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Primitives;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonPartition;
import io.ray.streaming.python.stream.PythonStreamSource;
import io.ray.streaming.runtime.serialization.MsgPackSerializer;
import io.ray.streaming.runtime.util.ReflectionUtils;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.msgpack.core.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -68,7 +73,7 @@ public class PythonGateway {
    Preconditions.checkNotNull(streamingContext);
    try {
      PythonStreamSource pythonStreamSource = PythonStreamSource.from(
          streamingContext, PythonFunction.fromFunction(pySourceFunc));
          streamingContext, new PythonFunction(pySourceFunc));
      referenceMap.put(getReferenceId(pythonStreamSource), pythonStreamSource);
      return serializer.serialize(getReferenceId(pythonStreamSource));
    } catch (Exception e) {

@@ -84,7 +89,7 @@ public class PythonGateway {
  }

  public byte[] createPyFunc(byte[] pyFunc) {
    PythonFunction function = PythonFunction.fromFunction(pyFunc);
    PythonFunction function = new PythonFunction(pyFunc);
    referenceMap.put(getReferenceId(function), function);
    return serializer.serialize(getReferenceId(function));
  }

@@ -98,15 +103,21 @@ public class PythonGateway {
  public byte[] callFunction(byte[] paramsBytes) {
    try {
      List<Object> params = (List<Object>) serializer.deserialize(paramsBytes);
      params = processReferenceParameters(params);
      params = processParameters(params);
      LOG.info("callFunction params {}", params);
      String className = (String) params.get(0);
      String funcName = (String) params.get(1);
      Class<?> clz = Class.forName(className, true, this.getClass().getClassLoader());
      Method method = ReflectionUtils.findMethod(clz, funcName);
      Class[] paramsTypes = params.subList(2, params.size()).stream()
          .map(Object::getClass).toArray(Class[]::new);
      Method method = findMethod(clz, funcName, paramsTypes);
      Object result = method.invoke(null, params.subList(2, params.size()).toArray());
      referenceMap.put(getReferenceId(result), result);
      return serializer.serialize(getReferenceId(result));
      if (returnReference(result)) {
        referenceMap.put(getReferenceId(result), result);
        return serializer.serialize(getReferenceId(result));
      } else {
        return serializer.serialize(result);
      }
    } catch (Exception e) {
      throw new RuntimeException(e);
    }

@@ -115,31 +126,78 @@ public class PythonGateway {
  public byte[] callMethod(byte[] paramsBytes) {
    try {
      List<Object> params = (List<Object>) serializer.deserialize(paramsBytes);
      params = processReferenceParameters(params);
      params = processParameters(params);
      LOG.info("callMethod params {}", params);
      Object obj = params.get(0);
      String methodName = (String) params.get(1);
      Method method = ReflectionUtils.findMethod(obj.getClass(), methodName);
      Class<?> clz = obj.getClass();
      Class[] paramsTypes = params.subList(2, params.size()).stream()
          .map(Object::getClass).toArray(Class[]::new);
      Method method = findMethod(clz, methodName, paramsTypes);
      Object result = method.invoke(obj, params.subList(2, params.size()).toArray());
      referenceMap.put(getReferenceId(result), result);
      return serializer.serialize(getReferenceId(result));
      if (returnReference(result)) {
        referenceMap.put(getReferenceId(result), result);
        return serializer.serialize(getReferenceId(result));
      } else {
        return serializer.serialize(result);
      }
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

  private List<Object> processReferenceParameters(List<Object> params) {
    return params.stream().map(this::processReferenceParameter)
  private static Method findMethod(Class<?> cls, String methodName, Class[] paramsTypes) {
    List<Method> methods = ReflectionUtils.findMethods(cls, methodName);
    if (methods.size() == 1) {
      return methods.get(0);
    }
    // Convert every boxed parameter type to its primitive counterpart.
    Class[] unwrappedTypes = Arrays.stream(paramsTypes)
        .map((Function<Class, Class>) Primitives::unwrap)
        .toArray(Class[]::new);
    Optional<Method> any = methods.stream()
        .filter(m -> Arrays.equals(m.getParameterTypes(), paramsTypes) ||
            Arrays.equals(m.getParameterTypes(), unwrappedTypes))
        .findAny();
    Preconditions.checkArgument(any.isPresent(),
        String.format("Method %s with type %s doesn't exist on class %s",
            methodName, Arrays.toString(paramsTypes), cls));
    return any.get();
  }

  private static boolean returnReference(Object value) {
    return !(value instanceof Number) && !(value instanceof String) && !(value instanceof byte[]);
  }

  public byte[] newInstance(byte[] classNameBytes) {
    String className = (String) serializer.deserialize(classNameBytes);
    try {
      Class<?> clz = Class.forName(className, true, this.getClass().getClassLoader());
      Object instance = clz.newInstance();
      referenceMap.put(getReferenceId(instance), instance);
      return serializer.serialize(getReferenceId(instance));
    } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) {
      throw new IllegalArgumentException(
          String.format("Create instance for class %s failed", className), e);
    }
  }

  private List<Object> processParameters(List<Object> params) {
    return params.stream().map(this::processParameter)
        .collect(Collectors.toList());
  }

  private Object processReferenceParameter(Object o) {
  private Object processParameter(Object o) {
    if (o instanceof String) {
      Object value = referenceMap.get(o);
      if (value != null) {
        return value;
      }
    }
    // Since python can't represent byte/short, we convert all Byte/Short to Integer
    if (o instanceof Byte || o instanceof Short) {
      return ((Number) o).intValue();
    }
    return o;
  }
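Editor's note: findMethod accepts a candidate whose parameter types match either the boxed argument types or their unwrapped primitives, and processParameter widens Byte/Short to Integer because msgpack on the python side only has plain ints. A hedged sketch of the caller's view; `client` stands in for ray.streaming.runtime.gateway_client, whose exact construction is assumed:

def set_parallelism(client, j_stream_ref, parallelism):
    # `parallelism` travels as a msgpack integer; the java gateway boxes it
    # to Integer, and findMethod still matches setParallelism(int) via the
    # Primitives.unwrap comparison shown above.
    return client.call_method(j_stream_ref, "setParallelism", parallelism)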
@@ -41,15 +41,11 @@ public class JobSchedulerImpl implements JobScheduler {
  public void schedule(JobGraph jobGraph, Map<String, String> jobConfig) {
    this.jobConfig = jobConfig;
    this.jobGraph = jobGraph;
    if (Ray.internal() == null) {
      System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
      Ray.init();
    }

    ExecutionGraph executionGraph = this.taskAssigner.assign(this.jobGraph);
    List<ExecutionNode> executionNodes = executionGraph.getExecutionNodeList();
    boolean hasPythonNode = executionNodes.stream()
        .allMatch(node -> node.getLanguage() == Language.PYTHON);
        .anyMatch(node -> node.getLanguage() == Language.PYTHON);
    RemoteCall.ExecutionGraph executionGraphPb = null;
    if (hasPythonNode) {
      executionGraphPb = new GraphPbBuilder().buildExecutionGraphPb(executionGraph);
@@ -2,6 +2,8 @@ package io.ray.streaming.runtime.schedule;

import io.ray.api.BaseActor;
import io.ray.api.Ray;
import io.ray.api.RayActor;
import io.ray.api.RayPyActor;
import io.ray.api.function.PyActorClass;
import io.ray.streaming.jobgraph.JobEdge;
import io.ray.streaming.jobgraph.JobGraph;

@@ -15,8 +17,11 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TaskAssignerImpl implements TaskAssigner {
  private static final Logger LOG = LoggerFactory.getLogger(TaskAssignerImpl.class);

  /**
   * Assign an optimized logical plan to execution graph.

@@ -61,11 +66,17 @@ public class TaskAssignerImpl implements TaskAssigner {

  private BaseActor createWorker(JobVertex jobVertex) {
    switch (jobVertex.getLanguage()) {
      case PYTHON:
        return Ray.createActor(
      case PYTHON: {
        RayPyActor worker = Ray.createActor(
            new PyActorClass("ray.streaming.runtime.worker", "JobWorker"));
      case JAVA:
        return Ray.createActor(JobWorker::new);
        LOG.info("Created python worker {}", worker);
        return worker;
      }
      case JAVA: {
        RayActor<JobWorker> worker = Ray.createActor(JobWorker::new);
        LOG.info("Created java worker {}", worker);
        return worker;
      }
      default:
        throw new UnsupportedOperationException(
            "Unsupported language " + jobVertex.getLanguage());
@@ -0,0 +1,62 @@
package io.ray.streaming.runtime.serialization;

import io.ray.streaming.message.KeyRecord;
import io.ray.streaming.message.Record;
import java.util.Arrays;
import java.util.List;

/**
 * A serializer for cross-lang serialization between java/python.
 * TODO: implement a more sophisticated serialization framework.
 */
public class CrossLangSerializer implements Serializer {
  private static final byte RECORD_TYPE_ID = 0;
  private static final byte KEY_RECORD_TYPE_ID = 1;

  private MsgPackSerializer msgPackSerializer = new MsgPackSerializer();

  public byte[] serialize(Object object) {
    Record record = (Record) object;
    Object value = record.getValue();
    Class<? extends Record> clz = record.getClass();
    if (clz == Record.class) {
      return msgPackSerializer.serialize(Arrays.asList(
          RECORD_TYPE_ID, record.getStream(), value));
    } else if (clz == KeyRecord.class) {
      KeyRecord keyRecord = (KeyRecord) record;
      Object key = keyRecord.getKey();
      return msgPackSerializer.serialize(Arrays.asList(
          KEY_RECORD_TYPE_ID, keyRecord.getStream(), key, value));
    } else {
      throw new UnsupportedOperationException(
          String.format("Serialize %s is unsupported.", record));
    }
  }

  @SuppressWarnings("unchecked")
  public Record deserialize(byte[] bytes) {
    List list = (List) msgPackSerializer.deserialize(bytes);
    Byte typeId = (Byte) list.get(0);
    switch (typeId) {
      case RECORD_TYPE_ID: {
        String stream = (String) list.get(1);
        Object value = list.get(2);
        Record record = new Record(value);
        record.setStream(stream);
        return record;
      }
      case KEY_RECORD_TYPE_ID: {
        String stream = (String) list.get(1);
        Object key = list.get(2);
        Object value = list.get(3);
        KeyRecord keyRecord = new KeyRecord(key, value);
        keyRecord.setStream(stream);
        return keyRecord;
      }
      default:
        throw new UnsupportedOperationException("Unsupported type " + typeId);
    }
  }
}
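Editor's note: on the wire this serializer is just a msgpack list, [0, stream, value] for plain records and [1, stream, key, value] for keyed ones. A python round-trip sketch under that layout; the dict stands in for the java Record classes:

import msgpack

RECORD_TYPE_ID, KEY_RECORD_TYPE_ID = 0, 1

def deserialize(data: bytes) -> dict:
    fields = msgpack.unpackb(data, raw=False)
    if fields[0] == RECORD_TYPE_ID:
        return {"stream": fields[1], "value": fields[2]}
    if fields[0] == KEY_RECORD_TYPE_ID:
        return {"stream": fields[1], "key": fields[2], "value": fields[3]}
    raise ValueError("Unsupported type {}".format(fields[0]))

wire = msgpack.packb([KEY_RECORD_TYPE_ID, "stream2", "key", "value"])
assert deserialize(wire)["key"] == "key"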
@@ -0,0 +1,15 @@
package io.ray.streaming.runtime.serialization;

import io.ray.runtime.serializer.FstSerializer;

public class JavaSerializer implements Serializer {
  @Override
  public byte[] serialize(Object object) {
    return FstSerializer.encode(object);
  }

  @Override
  public <T> T deserialize(byte[] bytes) {
    return FstSerializer.decode(bytes);
  }
}
@@ -1,4 +1,4 @@
package io.ray.streaming.runtime.python;
package io.ray.streaming.runtime.serialization;

import com.google.common.io.BaseEncoding;
import java.util.ArrayList;

@@ -31,6 +31,10 @@ public class MsgPackSerializer {
    Class<?> clz = obj.getClass();
    if (clz == Boolean.class) {
      packer.packBoolean((Boolean) obj);
    } else if (clz == Byte.class) {
      packer.packByte((Byte) obj);
    } else if (clz == Short.class) {
      packer.packShort((Short) obj);
    } else if (clz == Integer.class) {
      packer.packInt((Integer) obj);
    } else if (clz == Long.class) {

@@ -84,7 +88,11 @@ public class MsgPackSerializer {
        return value.asBooleanValue().getBoolean();
      case INTEGER:
        IntegerValue iv = value.asIntegerValue();
        if (iv.isInIntRange()) {
        if (iv.isInByteRange()) {
          return iv.toByte();
        } else if (iv.isInShortRange()) {
          return iv.toShort();
        } else if (iv.isInIntRange()) {
          return iv.toInt();
        } else if (iv.isInLongRange()) {
          return iv.toLong();
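Editor's note: the deserialize hunk above now returns the narrowest integer type that fits (Byte, Short, Integer, Long), matching the new packByte/packShort branches on the write path. msgpack-python draws no such distinction and always yields int, which is exactly why the gateway widens Byte/Short back to Integer; a quick symmetry check:

import msgpack

# msgpack encodes each value in its smallest wire representation, but
# python decodes them all as int; only the java side narrows the type.
for v in (1, 2 ** 15 - 1, 2 ** 31 - 1, 2 ** 63 - 1):
    assert msgpack.unpackb(msgpack.packb(v)) == v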
@@ -0,0 +1,12 @@
package io.ray.streaming.runtime.serialization;

public interface Serializer {
  byte CROSS_LANG_TYPE_ID = 0;
  byte JAVA_TYPE_ID = 1;
  byte PYTHON_TYPE_ID = 2;

  byte[] serialize(Object object);

  <T> T deserialize(byte[] bytes);
}
@@ -20,7 +20,7 @@ import java.util.Map;
 */
public class ChannelCreationParametersBuilder {

  public class Parameter {
  public static class Parameter {

    private ActorId actorId;
    private FunctionDescriptor asyncFunctionDescriptor;

@@ -138,7 +138,7 @@ public class ChannelCreationParametersBuilder {
      parameter.setAsyncFunctionDescriptor(pyAsyncFunctionDesc);
      parameter.setSyncFunctionDescriptor(pySyncFunctionDesc);
    } else {
      Preconditions.checkArgument(false, "Invalid actor type");
      throw new IllegalArgumentException("Invalid actor type");
    }
    parameters.add(parameter);
  }

@@ -152,10 +152,10 @@ public class ChannelCreationParametersBuilder {
  }

  public String toString() {
    String str = "";
    StringBuilder str = new StringBuilder();
    for (Parameter param : parameters) {
      str += param.toString();
      str.append(param.toString());
    }
    return str;
    return str.toString();
  }
}
@@ -40,7 +40,7 @@ public class DataReader {
    }
    long timerInterval = Long.parseLong(
        conf.getOrDefault(Config.TIMER_INTERVAL_MS, "-1"));
    String channelType = conf.getOrDefault(Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
    String channelType = conf.get(Config.CHANNEL_TYPE);
    boolean isMock = false;
    if (Config.MEMORY_CHANNEL.equals(channelType)) {
      isMock = true;
@@ -37,7 +37,7 @@ public class DataWriter {
      Map<String, String> conf) {
    Preconditions.checkArgument(!outputChannels.isEmpty());
    Preconditions.checkArgument(outputChannels.size() == toActors.size());
    ChannelCreationParametersBuilder initialParameters =
    ChannelCreationParametersBuilder initParameters =
        new ChannelCreationParametersBuilder().buildOutputQueueParameters(outputChannels, toActors);
    byte[][] outputChannelsBytes = outputChannels.stream()
        .map(ChannelID::idStrToBytes).toArray(byte[][]::new);

@@ -47,13 +47,14 @@ public class DataWriter {
    for (int i = 0; i < outputChannels.size(); i++) {
      msgIds[i] = 0;
    }
    String channelType = conf.getOrDefault(Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
    String channelType = conf.get(Config.CHANNEL_TYPE);
    boolean isMock = false;
    if (Config.MEMORY_CHANNEL.equals(channelType)) {
    if (Config.MEMORY_CHANNEL.equalsIgnoreCase(channelType)) {
      isMock = true;
      LOGGER.info("Using memory channel");
    }
    this.nativeWriterPtr = createWriterNative(
        initialParameters,
        initParameters,
        outputChannelsBytes,
        msgIds,
        channelSize,
@@ -19,6 +19,7 @@ public class ReflectionUtils {

  /**
   * For covariant return types, return the most specific method.
   *
   * @return all methods named by {@code methodName}
   */
  public static List<Method> findMethods(Class<?> cls, String methodName) {
@@ -1,5 +1,6 @@
package io.ray.streaming.runtime.worker;

import io.ray.api.Ray;
import io.ray.streaming.runtime.core.graph.ExecutionGraph;
import io.ray.streaming.runtime.core.graph.ExecutionNode;
import io.ray.streaming.runtime.core.graph.ExecutionNode.NodeType;

@@ -14,11 +15,8 @@ import io.ray.streaming.runtime.worker.context.WorkerContext;
import io.ray.streaming.runtime.worker.tasks.OneInputStreamTask;
import io.ray.streaming.runtime.worker.tasks.SourceStreamTask;
import io.ray.streaming.runtime.worker.tasks.StreamTask;
import io.ray.streaming.util.Config;

import java.io.Serializable;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -27,6 +25,8 @@ import org.slf4j.LoggerFactory;
 */
public class JobWorker implements Serializable {
  private static final Logger LOGGER = LoggerFactory.getLogger(JobWorker.class);
  // Special flag to indicate that this actor is not ready yet.
  private static final byte[] NOT_READY_FLAG = new byte[4];

  static {
    EnvUtil.loadNativeLibraries();

@@ -53,12 +53,11 @@ public class JobWorker implements Serializable {

    this.nodeType = executionNode.getNodeType();
    this.streamProcessor = ProcessBuilder
        .buildProcessor(executionNode.getStreamOperator());
    LOGGER.debug("Initializing StreamWorker, taskId: {}, operator: {}.", taskId, streamProcessor);
    LOGGER.info("Initializing StreamWorker, pid {}, taskId: {}, operator: {}.",
        EnvUtil.getJvmPid(), taskId, streamProcessor);

    String channelType = (String) this.config.getOrDefault(
        Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
    if (channelType.equals(Config.NATIVE_CHANNEL)) {
    if (!Ray.getRuntimeContext().isSingleProcess()) {
      transferHandler = new TransferHandler();
    }
    task = createStreamTask();

@@ -124,6 +123,9 @@ public class JobWorker implements Serializable {
   * and receive result from this actor
   */
  public byte[] onReaderMessageSync(byte[] buffer) {
    if (transferHandler == null) {
      return NOT_READY_FLAG;
    }
    return transferHandler.onReaderMessageSync(buffer);
  }

@@ -139,6 +141,9 @@ public class JobWorker implements Serializable {
   * and receive result from this actor
   */
  public byte[] onWriterMessageSync(byte[] buffer) {
    if (transferHandler == null) {
      return NOT_READY_FLAG;
    }
    return transferHandler.onWriterMessageSync(buffer);
  }
}
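Editor's note: a peer that calls onReaderMessageSync/onWriterMessageSync before the transfer handler exists gets the 4-byte zero flag back and should retry. A hedged sketch of that handshake from the caller's side; `call_sync` stands in for the actual actor method invocation:

import time

NOT_READY_FLAG = bytes(4)  # mirrors JobWorker.NOT_READY_FLAG: four zero bytes

def call_until_ready(call_sync, buffer: bytes, attempts: int = 10) -> bytes:
    for _ in range(attempts):
        result = call_sync(buffer)
        if result != NOT_READY_FLAG:
            return result
        # the worker actor exists, but its transfer handler does not yet
        time.sleep(0.1)
    raise TimeoutError("peer transfer handler never became ready")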
@@ -1,7 +1,9 @@
package io.ray.streaming.runtime.worker.tasks;

import io.ray.runtime.serializer.Serializer;
import io.ray.streaming.runtime.core.processor.Processor;
import io.ray.streaming.runtime.serialization.CrossLangSerializer;
import io.ray.streaming.runtime.serialization.JavaSerializer;
import io.ray.streaming.runtime.serialization.Serializer;
import io.ray.streaming.runtime.transfer.Message;
import io.ray.streaming.runtime.worker.JobWorker;
import io.ray.streaming.util.Config;

@@ -10,11 +12,15 @@ public abstract class InputStreamTask extends StreamTask {
  private volatile boolean running = true;
  private volatile boolean stopped = false;
  private long readTimeoutMillis;
  private final io.ray.streaming.runtime.serialization.Serializer javaSerializer;
  private final io.ray.streaming.runtime.serialization.Serializer crossLangSerializer;

  public InputStreamTask(int taskId, Processor processor, JobWorker streamWorker) {
    super(taskId, processor, streamWorker);
    readTimeoutMillis = Long.parseLong((String) streamWorker.getConfig()
        .getOrDefault(Config.READ_TIMEOUT_MS, Config.DEFAULT_READ_TIMEOUT_MS));
    javaSerializer = new JavaSerializer();
    crossLangSerializer = new CrossLangSerializer();
  }

  @Override

@@ -26,9 +32,15 @@ public abstract class InputStreamTask extends StreamTask {
    while (running) {
      Message item = reader.read(readTimeoutMillis);
      if (item != null) {
        byte[] bytes = new byte[item.body().remaining()];
        byte[] bytes = new byte[item.body().remaining() - 1];
        byte typeId = item.body().get();
        item.body().get(bytes);
        Object obj = Serializer.decode(bytes, Object.class);
        Object obj;
        if (typeId == Serializer.JAVA_TYPE_ID) {
          obj = javaSerializer.deserialize(bytes);
        } else {
          obj = crossLangSerializer.deserialize(bytes);
        }
        processor.process(obj);
      }
    }
@@ -26,7 +26,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class StreamTask implements Runnable {

  private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class);

  protected int taskId;

@@ -53,8 +52,8 @@ public abstract class StreamTask implements Runnable {
    String queueSize = worker.getConfig()
        .getOrDefault(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT);
    queueConf.put(Config.CHANNEL_SIZE, queueSize);
    String channelType = worker.getConfig()
        .getOrDefault(Config.CHANNEL_TYPE, Config.MEMORY_CHANNEL);
    String channelType = Ray.getRuntimeContext().isSingleProcess() ?
        Config.MEMORY_CHANNEL : Config.NATIVE_CHANNEL;
    queueConf.put(Config.CHANNEL_TYPE, channelType);

    ExecutionGraph executionGraph = worker.getExecutionGraph();

@@ -82,7 +81,7 @@ public abstract class StreamTask implements Runnable {
        LOG.info("Create DataWriter succeed.");
        writers.put(edge, writer);
        Partition partition = edge.getPartition();
        collectors.add(new OutputCollector(channelIDs, writer, partition));
        collectors.add(new OutputCollector(writer, channelIDs, outputActors.values(), partition));
      }
    }

@@ -106,8 +105,8 @@ public abstract class StreamTask implements Runnable {
      reader = new DataReader(channelIDs, inputActors, queueConf);
    }

    RuntimeContext runtimeContext = new RayRuntimeContext(worker.getExecutionTask(),
        worker.getConfig(), executionNode.getParallelism());
    RuntimeContext runtimeContext = new RayRuntimeContext(
        worker.getExecutionTask(), worker.getConfig(), executionNode.getParallelism());

    processor.open(collectors, runtimeContext);
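Editor's note: channel selection now follows the runtime instead of user config: single-process (local) runs get the mock in-memory channel, everything else the native one. The same decision, sketched with is_single_process standing in for Ray.getRuntimeContext().isSingleProcess():

MEMORY_CHANNEL = "memory_channel"  # Config.MEMORY_CHANNEL
NATIVE_CHANNEL = "native_channel"  # Config.NATIVE_CHANNEL

def pick_channel_type(is_single_process: bool) -> str:
    # local runs have no raylet transfer to back the native channel
    return MEMORY_CHANNEL if is_single_process else NATIVE_CHANNEL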
@@ -24,11 +24,13 @@ public abstract class BaseUnitTest {

  @BeforeMethod
  public void testBegin(Method method) {
    LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: " + method.getName() + " began >>>>>>>>>>>>>>>>>>>>");
    LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: {}.{} began >>>>>>>>>>>>>>>>>>>>",
        method.getDeclaringClass(), method.getName());
  }

  @AfterMethod
  public void testEnd(Method method) {
    LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: " + method.getName() + " end >>>>>>>>>>>>>>>>>>");
    LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: {}.{} end >>>>>>>>>>>>>>>>>>>>",
        method.getDeclaringClass(), method.getName());
  }
}
@@ -80,7 +80,7 @@ public class ExecutionGraphTest extends BaseUnitTest {

  public static JobGraph buildJobGraph() {
    StreamingContext streamingContext = StreamingContext.buildContext();
    DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext,
    DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
        Lists.newArrayList("a", "b", "c"));
    StreamSink streamSink = dataStream.sink(x -> LOG.info(x));

@@ -0,0 +1,56 @@
package io.ray.streaming.runtime.demo;

import io.ray.api.Ray;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.FilterFunction;
import io.ray.streaming.api.function.impl.MapFunction;
import io.ray.streaming.api.stream.DataStreamSource;
import io.ray.streaming.runtime.BaseUnitTest;
import java.io.Serializable;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.annotations.Test;

public class HybridStreamTest extends BaseUnitTest implements Serializable {
  private static final Logger LOG = LoggerFactory.getLogger(HybridStreamTest.class);

  public static class Mapper1 implements MapFunction<Object, Object> {

    @Override
    public Object map(Object value) {
      LOG.info("HybridStreamTest Mapper1 {}", value);
      return value.toString();
    }
  }

  public static class Filter1 implements FilterFunction<Object> {

    @Override
    public boolean filter(Object value) throws Exception {
      LOG.info("HybridStreamTest Filter1 {}", value);
      return !value.toString().contains("b");
    }
  }

  @Test
  public void testHybridDataStream() throws InterruptedException {
    Ray.shutdown();
    StreamingContext context = StreamingContext.buildContext();
    DataStreamSource<String> streamSource =
        DataStreamSource.fromCollection(context, Arrays.asList("a", "b", "c"));
    streamSource
        .map(x -> x + x)
        .asPythonStream()
        .map("ray.streaming.tests.test_hybrid_stream", "map_func1")
        .filter("ray.streaming.tests.test_hybrid_stream", "filter_func1")
        .asJavaStream()
        .sink(x -> System.out.println("HybridStreamTest: " + x));
    context.execute("HybridStreamTestJob");
    TimeUnit.SECONDS.sleep(3);
    context.stop();
    LOG.info("HybridStreamTest succeed");
  }
}
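Editor's note: the test above names map_func1 and filter_func1 in ray.streaming.tests.test_hybrid_stream. The real module ships with the python package; bodies like the following would satisfy the pipeline, and are illustrative only:

# Plausible shape of the referenced python functions; plain callables
# are wrapped into streaming functions on the python side.
def map_func1(x):
    return "python_" + x

def filter_func1(x):
    return "b" not in x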
@@ -1,6 +1,7 @@
package io.ray.streaming.runtime.demo;

import com.google.common.collect.ImmutableMap;
import io.ray.api.Ray;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.FlatMapFunction;
import io.ray.streaming.api.function.impl.ReduceFunction;

@@ -29,6 +30,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable {

  @Test
  public void testWordCount() {
    Ray.shutdown();
    StreamingContext streamingContext = StreamingContext.buildContext();
    Map<String, String> config = new HashMap<>();
    config.put(Config.STREAMING_BATCH_MAX_COUNT, "1");

@@ -36,7 +38,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable {
    streamingContext.withConfig(config);
    List<String> text = new ArrayList<>();
    text.add("hello world eagle eagle eagle");
    DataStreamSource<String> streamSource = DataStreamSource.buildSource(streamingContext, text);
    DataStreamSource<String> streamSource = DataStreamSource.fromCollection(streamingContext, text);
    streamSource
        .flatMap((FlatMapFunction<String, WordAndCount>) (value, collector) -> {
          String[] records = value.split(" ");

@@ -62,6 +64,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable {
      }
    }
    Assert.assertEquals(wordCount, ImmutableMap.of("eagle", 3, "hello", 1));
    streamingContext.stop();
  }

  private static class WordAndCount implements Serializable {
@@ -3,6 +3,7 @@ package io.ray.streaming.runtime.python;
import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.jobgraph.JobGraphBuilder;
import io.ray.streaming.runtime.serialization.MsgPackSerializer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
@@ -57,7 +57,7 @@ public class TaskAssignerImplTest extends BaseUnitTest {

  public JobGraph buildDataSyncPlan() {
    StreamingContext streamingContext = StreamingContext.buildContext();
    DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext,
    DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
        Lists.newArrayList("a", "b", "c"));
    DataStreamSink streamSink = dataStream.sink(LOGGER::info);
    JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink));
@@ -0,0 +1,26 @@
package io.ray.streaming.runtime.serialization;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;

import io.ray.streaming.message.KeyRecord;
import io.ray.streaming.message.Record;
import org.apache.commons.lang3.builder.EqualsBuilder;
import org.testng.annotations.Test;

public class CrossLangSerializerTest {

  @Test
  @SuppressWarnings("unchecked")
  public void testSerialize() {
    CrossLangSerializer serializer = new CrossLangSerializer();
    Record record = new Record("value");
    record.setStream("stream1");
    assertTrue(EqualsBuilder.reflectionEquals(record,
        serializer.deserialize(serializer.serialize(record))));
    KeyRecord keyRecord = new KeyRecord("key", "value");
    keyRecord.setStream("stream2");
    assertEquals(keyRecord,
        serializer.deserialize(serializer.serialize(keyRecord)));
  }
}
@@ -1,4 +1,7 @@
package io.ray.streaming.runtime.python;
package io.ray.streaming.runtime.serialization;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;

import java.util.ArrayList;
import java.util.Arrays;

@@ -6,25 +9,37 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.testng.annotations.Test;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;

@SuppressWarnings("unchecked")
public class MsgPackSerializerTest {

  @Test
  public void testSerializeByte() {
    MsgPackSerializer serializer = new MsgPackSerializer();

    assertEquals(serializer.deserialize(
        serializer.serialize((byte) 1)), (byte) 1);
  }

  @Test
  public void testSerialize() {
    MsgPackSerializer serializer = new MsgPackSerializer();

    assertEquals(serializer.deserialize(
        serializer.serialize(Short.MAX_VALUE)), Short.MAX_VALUE);
    assertEquals(serializer.deserialize(
        serializer.serialize(Integer.MAX_VALUE)), Integer.MAX_VALUE);
    assertEquals(serializer.deserialize(
        serializer.serialize(Long.MAX_VALUE)), Long.MAX_VALUE);

    Map map = new HashMap();
    List list = new ArrayList<>();
    list.add(null);
    list.add(true);
    list.add(1);
    list.add(1.0d);
    list.add("str");
    map.put("k1", "value1");
    map.put("k2", 2);
    map.put("k2", new HashMap<>());
    map.put("k3", list);
    byte[] bytes = serializer.serialize(map);
    Object o = serializer.deserialize(bytes);
@@ -5,6 +5,7 @@ import io.ray.api.Ray;
import io.ray.api.RayActor;
import io.ray.api.options.ActorCreationOptions;
import io.ray.api.options.ActorCreationOptions.Builder;
import io.ray.runtime.config.RayConfig;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.FlatMapFunction;
import io.ray.streaming.api.function.impl.ReduceFunction;

@@ -67,7 +68,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
    System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
    System.setProperty("ray.run-mode", "CLUSTER");
    System.setProperty("ray.redirect-output", "true");
    // ray init
    RayConfig.reset();
    Ray.init();
  }

@@ -142,6 +143,14 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {

  @Test(timeOut = 60000)
  public void testWordCount() {
    Ray.shutdown();
    System.setProperty("ray.resources", "CPU:4,RES-A:4");
    System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");

    System.setProperty("ray.run-mode", "CLUSTER");
    System.setProperty("ray.redirect-output", "true");
    // ray init
    Ray.init();
    LOGGER.info("testWordCount");
    LOGGER.info("StreamingQueueTest.testWordCount run-mode: {}",
        System.getProperty("ray.run-mode"));

@@ -157,7 +166,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
    streamingContext.withConfig(config);
    List<String> text = new ArrayList<>();
    text.add("hello world eagle eagle eagle");
    DataStreamSource<String> streamSource = DataStreamSource.buildSource(streamingContext, text);
    DataStreamSource<String> streamSource = DataStreamSource.fromCollection(streamingContext, text);
    streamSource
        .flatMap((FlatMapFunction<String, WordAndCount>) (value, collector) -> {
          String[] records = value.split(" ");

@@ -176,7 +185,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
      serializeResultToFile(resultFile, wordCount);
    });

    streamingContext.execute("testWordCount");
    streamingContext.execute("testSQWordCount");

    Map<String, Integer> checkWordCount =
        (Map<String, Integer>) deserializeResultFromFile(resultFile);
@@ -23,8 +23,11 @@ bazel test //streaming/java:all --test_tag_filters="checkstyle" --build_tests_only

echo "Running streaming tests."
java -cp "$ROOT_DIR"/../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar\
    org.testng.TestNG -d /tmp/ray_streaming_java_test_output "$ROOT_DIR"/testng.xml
    org.testng.TestNG -d /tmp/ray_streaming_java_test_output "$ROOT_DIR"/testng.xml ||
    exit_code=$?
if [ -z ${exit_code+x} ]; then
  exit_code=0
fi
echo "Streaming TestNG results"
if [ -f "/tmp/ray_streaming_java_test_output/testng-results.xml" ] ; then
  cat /tmp/ray_streaming_java_test_output/testng-results.xml
@@ -1,10 +1,13 @@
import logging
import pickle
import typing
from abc import ABC, abstractmethod

from ray import Language
from ray.actor import ActorHandle
from ray.streaming import function
from ray.streaming import message
from ray.streaming import partition
from ray.streaming.runtime import serialization
from ray.streaming.runtime.transfer import ChannelID, DataWriter

logger = logging.getLogger(__name__)

@@ -31,19 +34,46 @@ class CollectionCollector(Collector):


class OutputCollector(Collector):
    def __init__(self, channel_ids: typing.List[str], writer: DataWriter,
    def __init__(self, writer: DataWriter, channel_ids: typing.List[str],
                 target_actors: typing.List[ActorHandle],
                 partition_func: partition.Partition):
        self._channel_ids = [ChannelID(id_str) for id_str in channel_ids]
        self._writer = writer
        self._channel_ids = [ChannelID(id_str) for id_str in channel_ids]
        self._target_languages = []
        for actor in target_actors:
            if actor._ray_actor_language == Language.PYTHON:
                self._target_languages.append(function.Language.PYTHON)
            elif actor._ray_actor_language == Language.JAVA:
                self._target_languages.append(function.Language.JAVA)
            else:
                raise Exception("Unsupported language {}"
                                .format(actor._ray_actor_language))
        self._partition_func = partition_func
        self.python_serializer = serialization.PythonSerializer()
        self.cross_lang_serializer = serialization.CrossLangSerializer()
        logger.info(
            "Create OutputCollector, channel_ids {}, partition_func {}".format(
                channel_ids, partition_func))

    def collect(self, record):
        partitions = self._partition_func.partition(record,
                                                    len(self._channel_ids))
        serialized_message = pickle.dumps(record)
        partitions = self._partition_func \
            .partition(record, len(self._channel_ids))
        python_buffer = None
        cross_lang_buffer = None
        for partition_index in partitions:
            self._writer.write(self._channel_ids[partition_index],
                               serialized_message)
            if self._target_languages[partition_index] == \
                    function.Language.PYTHON:
                # avoid repeated serialization
                if python_buffer is None:
                    python_buffer = self.python_serializer.serialize(record)
                self._writer.write(
                    self._channel_ids[partition_index],
                    serialization._PYTHON_TYPE_ID + python_buffer)
            else:
                # avoid repeated serialization
                if cross_lang_buffer is None:
                    cross_lang_buffer = self.cross_lang_serializer.serialize(
                        record)
                self._writer.write(
                    self._channel_ids[partition_index],
                    serialization._CROSS_LANG_TYPE_ID + cross_lang_buffer)
@@ -1,4 +1,4 @@
from abc import ABC
from abc import ABC, abstractmethod

from ray.streaming import function
from ray.streaming import partition

@@ -19,7 +19,6 @@ class Stream(ABC):
            self.streaming_context = input_stream.streaming_context
        else:
            self.streaming_context = streaming_context
        self.parallelism = 1

    def get_streaming_context(self):
        return self.streaming_context

@@ -29,7 +28,8 @@ class Stream(ABC):
        Returns:
            the parallelism of this transformation
        """
        return self.parallelism
        return self._gateway_client(). \
            call_method(self._j_stream, "getParallelism")

    def set_parallelism(self, parallelism: int):
        """Sets the parallelism of this transformation

@@ -40,7 +40,6 @@ class Stream(ABC):
        Returns:
            self
        """
        self.parallelism = parallelism
        self._gateway_client(). \
            call_method(self._j_stream, "setParallelism", parallelism)
        return self

@@ -60,6 +59,10 @@ class Stream(ABC):
        return self._gateway_client(). \
            call_method(self._j_stream, "getId")

    @abstractmethod
    def get_language(self):
        pass

    def _gateway_client(self):
        return self.get_streaming_context()._gateway_client
@@ -75,6 +78,9 @@ class DataStream(Stream):
        super().__init__(
            input_stream, j_stream, streaming_context=streaming_context)

    def get_language(self):
        return function.Language.PYTHON

    def map(self, func):
        """
        Applies a Map transformation on a :class:`DataStream`.

@@ -158,6 +164,7 @@ class DataStream(Stream):
        Returns:
            A KeyDataStream
        """
        self._check_partition_call()
        if not isinstance(func, function.KeyFunction):
            func = function.SimpleKeyFunction(func)
        j_func = self._gateway_client().create_py_func(

@@ -175,6 +182,7 @@ class DataStream(Stream):
        Returns:
            The DataStream with broadcast partitioning set.
        """
        self._check_partition_call()
        self._gateway_client().call_method(self._j_stream, "broadcast")
        return self

@@ -191,6 +199,7 @@ class DataStream(Stream):
        Returns:
            The DataStream with specified partitioning set.
        """
        self._check_partition_call()
        if not isinstance(partition_func, partition.Partition):
            partition_func = partition.SimplePartition(partition_func)
        j_partition = self._gateway_client().create_py_func(

@@ -199,6 +208,16 @@ class DataStream(Stream):
            call_method(self._j_stream, "partitionBy", j_partition)
        return self

    def _check_partition_call(self):
        """
        If the parent stream is a java stream, partition-related methods
        can't be called on this python stream.
        """
        if self.input_stream is not None and \
                self.input_stream.get_language() == function.Language.JAVA:
            raise Exception("Partition related methods can't be called on a "
                            "python stream if parent stream is a java stream.")

    def sink(self, func):
        """
        Create a StreamSink with the given sink.
@@ -217,8 +236,97 @@ class DataStream(Stream):
            call_method(self._j_stream, "sink", j_func)
        return StreamSink(self, j_stream, func)

    def as_java_stream(self):
        """
        Convert this stream into a java JavaDataStream.
        The converted stream and this stream are the same logical stream,
        which has the same stream id. Changes in the converted stream will
        be reflected in this stream and vice versa.
        """
        j_stream = self._gateway_client(). \
            call_method(self._j_stream, "asJavaStream")
        return JavaDataStream(self, j_stream)


class KeyDataStream(Stream):

class JavaDataStream(Stream):
    """
    Represents a stream of data which applies a transformation executed by
    java. It's also a wrapper of the java
    `org.ray.streaming.api.stream.DataStream`
    """

    def __init__(self, input_stream, j_stream, streaming_context=None):
        super().__init__(
            input_stream, j_stream, streaming_context=streaming_context)

    def get_language(self):
        return function.Language.JAVA

    def map(self, java_func_class):
        """See org.ray.streaming.api.stream.DataStream.map"""
        return JavaDataStream(self, self._unary_call("map", java_func_class))

    def flat_map(self, java_func_class):
        """See org.ray.streaming.api.stream.DataStream.flatMap"""
        return JavaDataStream(self, self._unary_call("flatMap",
                                                     java_func_class))

    def filter(self, java_func_class):
        """See org.ray.streaming.api.stream.DataStream.filter"""
        return JavaDataStream(self, self._unary_call("filter",
                                                     java_func_class))

    def key_by(self, java_func_class):
        """See org.ray.streaming.api.stream.DataStream.keyBy"""
        self._check_partition_call()
        return JavaKeyDataStream(self,
                                 self._unary_call("keyBy", java_func_class))

    def broadcast(self, java_func_class):
        """See org.ray.streaming.api.stream.DataStream.broadcast"""
        self._check_partition_call()
        return JavaDataStream(self,
                              self._unary_call("broadcast", java_func_class))

    def partition_by(self, java_func_class):
        """See org.ray.streaming.api.stream.DataStream.partitionBy"""
        self._check_partition_call()
        return JavaDataStream(self,
                              self._unary_call("partitionBy", java_func_class))

    def sink(self, java_func_class):
        """See org.ray.streaming.api.stream.DataStream.sink"""
        return JavaStreamSink(self, self._unary_call("sink", java_func_class))

    def as_python_stream(self):
        """
        Convert this stream into a python DataStream.
        The converted stream and this stream are the same logical stream,
        which has the same stream id. Changes in the converted stream will
        be reflected in this stream and vice versa.
        """
        j_stream = self._gateway_client(). \
            call_method(self._j_stream, "asPythonStream")
        return DataStream(self, j_stream)

    def _check_partition_call(self):
        """
        If the parent stream is a python stream, partition-related methods
        can't be called on this java stream.
        """
        if self.input_stream is not None and \
                self.input_stream.get_language() == function.Language.PYTHON:
            raise Exception("Partition related methods can't be called on a "
                            "java stream if parent stream is a python stream.")

    def _unary_call(self, func_name, java_func_class):
        j_func = self._gateway_client().new_instance(java_func_class)
        j_stream = self._gateway_client(). \
            call_method(self._j_stream, func_name, j_func)
        return j_stream


class KeyDataStream(DataStream):
    """Represents a DataStream returned by a key-by operation.
    Wrapper of java io.ray.streaming.python.stream.PythonKeyDataStream
    """
@@ -251,6 +359,43 @@ class KeyDataStream(Stream):
            call_method(self._j_stream, "reduce", j_func)
        return DataStream(self, j_stream)

    def as_java_stream(self):
        """
        Convert this stream into a java KeyDataStream.
        The converted stream and this stream are the same logical stream,
        which has the same stream id. Changes in the converted stream will
        be reflected in this stream and vice versa.
        """
        j_stream = self._gateway_client(). \
            call_method(self._j_stream, "asJavaStream")
        return JavaKeyDataStream(self, j_stream)


class JavaKeyDataStream(JavaDataStream):
    """
    Represents a DataStream returned by a key-by operation in java.
    Wrapper of org.ray.streaming.api.stream.KeyDataStream
    """

    def __init__(self, input_stream, j_stream):
        super().__init__(input_stream, j_stream)

    def reduce(self, java_func_class):
        """See org.ray.streaming.api.stream.KeyDataStream.reduce"""
        return JavaDataStream(self,
                              super()._unary_call("reduce", java_func_class))

    def as_python_stream(self):
        """
        Convert this stream into a python KeyDataStream.
        The converted stream and this stream are the same logical stream,
        which has the same stream id. Changes in the converted stream will
        be reflected in this stream and vice versa.
        """
        j_stream = self._gateway_client(). \
            call_method(self._j_stream, "asPythonStream")
        return KeyDataStream(self, j_stream)


class StreamSource(DataStream):
    """Represents a source of the DataStream.
@ -261,9 +406,12 @@ class StreamSource(DataStream):
|
|||
super().__init__(None, j_stream, streaming_context=streaming_context)
|
||||
self.source_func = source_func
|
||||
|
||||
def get_language(self):
|
||||
return function.Language.PYTHON
|
||||
|
||||
@staticmethod
|
||||
def build_source(streaming_context, func):
|
||||
"""Build a StreamSource source from a collection.
|
||||
"""Build a StreamSource source from a source function.
|
||||
Args:
|
||||
streaming_context: Stream context
|
||||
func: A instance of `SourceFunction`
|
||||
|
@ -275,6 +423,34 @@ class StreamSource(DataStream):
|
|||
return StreamSource(j_stream, streaming_context, func)
|
||||
|
||||
|
||||
class JavaStreamSource(JavaDataStream):
|
||||
"""Represents a source of the java DataStream.
|
||||
Wrapper of java org.ray.streaming.api.stream.DataStreamSource
|
||||
"""
|
||||
|
||||
def __init__(self, j_stream, streaming_context):
|
||||
super().__init__(None, j_stream, streaming_context=streaming_context)
|
||||
|
||||
def get_language(self):
|
||||
return function.Language.JAVA
|
||||
|
||||
@staticmethod
|
||||
def build_source(streaming_context, java_source_func_class):
|
||||
"""Build a java StreamSource source from a java source function.
|
||||
Args:
|
||||
streaming_context: Stream context
|
||||
java_source_func_class: qualified class name of java SourceFunction
|
||||
Returns:
|
||||
A java StreamSource
|
||||
"""
|
||||
j_func = streaming_context._gateway_client() \
|
||||
.new_instance(java_source_func_class)
|
||||
j_stream = streaming_context._gateway_client() \
|
||||
.call_function("org.ray.streaming.api.stream.DataStreamSource"
|
||||
"fromSource", streaming_context._j_ctx, j_func)
|
||||
return JavaStreamSource(j_stream, streaming_context)
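
A hedged usage sketch for build_source above, assuming a built StreamingContext named ctx and that the module is importable as ray.streaming.datastream; the Java class name is illustrative and would need to implement the Java SourceFunction interface:

    from ray.streaming.datastream import JavaStreamSource

    j_source = JavaStreamSource.build_source(
        ctx, "io.ray.streaming.runtime.demo.MySourceFunction")
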

class StreamSink(Stream):
    """Represents a sink of the DataStream.
    Wrapper of java io.ray.streaming.python.stream.PythonStreamSink

@ -282,3 +458,18 @@ class StreamSink(Stream):

    def __init__(self, input_stream, j_stream, func):
        super().__init__(input_stream, j_stream)

    def get_language(self):
        return function.Language.PYTHON


class JavaStreamSink(Stream):
    """Represents a sink of the java DataStream.
    Wrapper of java org.ray.streaming.api.stream.StreamSink
    """

    def __init__(self, input_stream, j_stream):
        super().__init__(input_stream, j_stream)

    def get_language(self):
        return function.Language.JAVA

@ -1,13 +1,19 @@
import enum
import importlib
import inspect
import sys
from abc import ABC, abstractmethod
import typing
from abc import ABC, abstractmethod

from ray import cloudpickle
from ray.streaming.runtime import gateway_client


class Language(enum.Enum):
    JAVA = 0
    PYTHON = 1


class Function(ABC):
    """The base interface for all user-defined functions."""

@ -60,6 +66,7 @@ class MapFunction(Function):
    for each input element.
    """

    @abstractmethod
    def map(self, value):
        pass

@ -70,6 +77,7 @@ class FlatMapFunction(Function):
    transform them into zero, one, or more elements.
    """

    @abstractmethod
    def flat_map(self, value, collector):
        """Takes an element from the input data set and transforms it into
        zero, one, or more elements.

@ -87,6 +95,7 @@ class FilterFunction(Function):
    The predicate decides whether to keep the element, or to discard it.
    """

    @abstractmethod
    def filter(self, value):
        """The filter function that evaluates the predicate.

@ -106,6 +115,7 @@ class KeyFunction(Function):
    deterministic key for that object.
    """

    @abstractmethod
    def key_by(self, value):
        """User-defined function that deterministically extracts the key from
        an object.

@ -126,6 +136,7 @@ class ReduceFunction(Function):
    them into one.
    """

    @abstractmethod
    def reduce(self, old_value, new_value):
        """
        The core method of ReduceFunction, combining two values into one.

@ -145,6 +156,7 @@ class ReduceFunction(Function):
class SinkFunction(Function):
    """Interface for implementing user-defined sink functionality."""

    @abstractmethod
    def sink(self, value):
        """Writes the given value to the sink. This function is called for
        every record."""

@ -283,7 +295,8 @@ def load_function(descriptor_func_bytes: bytes):
    Returns:
        a streaming function
    """
    function_bytes, module_name, class_name, function_name, function_interface\
    assert len(descriptor_func_bytes) > 0
    function_bytes, module_name, function_name, function_interface\
        = gateway_client.deserialize(descriptor_func_bytes)
    if function_bytes:
        return deserialize(function_bytes)

@ -292,16 +305,18 @@ def load_function(descriptor_func_bytes: bytes):
    assert function_interface
    function_interface = getattr(sys.modules[__name__], function_interface)
    mod = importlib.import_module(module_name)
    if class_name:
        assert function_name is None
        cls = getattr(mod, class_name)
        assert issubclass(cls, function_interface)
        return cls()
    else:
        assert function_name
        func = getattr(mod, function_name)
    assert function_name
    func = getattr(mod, function_name)
    # If func is a python function, the user function is a simple python
    # function, which will be wrapped as a SimpleXXXFunction.
    # If func is a python class, the user function is a subclass
    # of XXXFunction.
    if inspect.isfunction(func):
        simple_func_class = _get_simple_function_class(function_interface)
        return simple_func_class(func)
    else:
        assert issubclass(func, function_interface)
        return func()


def _get_simple_function_class(function_interface):
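
A minimal sketch of the two user-function shapes load_function accepts, assuming ray.streaming.function is importable as in the tests below; the names are illustrative. A plain function is wrapped by _get_simple_function_class, while a class is instantiated directly:

    from ray.streaming import function

    # plain function: wrapped into the Simple* counterpart of MapFunction
    def double(x):
        return x * 2

    # subclass: checked against the interface and instantiated as-is
    class Double(function.MapFunction):
        def map(self, value):
            return value * 2
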

@ -8,6 +8,14 @@ class Record:
    def __repr__(self):
        return "Record({})".format(self.value)

    def __eq__(self, other):
        if type(self) is type(other):
            return (self.stream, self.value) == (other.stream, other.value)
        return False

    def __hash__(self):
        return hash((self.stream, self.value))


class KeyRecord(Record):
    """Data record in a keyed data stream"""

@ -15,3 +23,12 @@ class KeyRecord(Record):
    def __init__(self, key, value):
        super().__init__(value)
        self.key = key

    def __eq__(self, other):
        if type(self) is type(other):
            return (self.stream, self.key, self.value) ==\
                (other.stream, other.key, other.value)
        return False

    def __hash__(self):
        return hash((self.stream, self.key, self.value))
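
With __eq__ and __hash__ above, records compare by stream plus payload (plus key for KeyRecord), which the serialization round-trip test below relies on. A tiny sketch:

    from ray.streaming.message import Record

    r1 = Record("a")
    r1.stream = "s1"
    r2 = Record("a")
    r2.stream = "s1"
    assert r1 == r2 and hash(r1) == hash(r2)
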

@ -1,4 +1,5 @@
import importlib
import inspect
from abc import ABC, abstractmethod

from ray import cloudpickle

@ -96,22 +97,22 @@ def load_partition(descriptor_partition_bytes: bytes):
    Returns:
        partition function
    """
    partition_bytes, module_name, class_name, function_name =\
    assert len(descriptor_partition_bytes) > 0
    partition_bytes, module_name, function_name =\
        gateway_client.deserialize(descriptor_partition_bytes)
    if partition_bytes:
        return deserialize(partition_bytes)
    else:
        assert module_name
        mod = importlib.import_module(module_name)
        # If class_name is not None, the user partition is a subclass
        # of Partition.
        # If function_name is not None, the user partition is a simple python
        assert function_name
        func = getattr(mod, function_name)
        # If func is a python function, the user partition is a simple python
        # function, which will be wrapped as a SimplePartition.
        if class_name:
            assert function_name is None
            cls = getattr(mod, class_name)
            return cls()
        else:
            assert function_name
            func = getattr(mod, function_name)
        # If func is a python class, the user partition is a subclass
        # of Partition.
        if inspect.isfunction(func):
            return SimplePartition(func)
        else:
            assert issubclass(func, Partition)
            return func()
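
A hedged sketch of a plain-function partition that load_partition would wrap in SimplePartition; the (record, num_partition) -> list-of-channel-indices convention is assumed from the Partition interface, and the name is illustrative:

    def hash_partition(record, num_partition):
        # send each record to one channel chosen by hash
        return [abs(hash(record)) % num_partition]
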

@ -55,6 +55,11 @@ class GatewayClient:
        call = self._python_gateway_actor.callMethod.remote(java_params)
        return deserialize(ray.get(call))

    def new_instance(self, java_class_name):
        call = self._python_gateway_actor.newInstance.remote(
            serialize(java_class_name))
        return deserialize(ray.get(call))


def serialize(obj) -> bytes:
    """Serialize a python object which can be deserialized by `PythonGateway`

@ -53,7 +53,9 @@ class ExecutionEdge:
        self.src_node_id = edge_pb.src_node_id
        self.target_node_id = edge_pb.target_node_id
        partition_bytes = edge_pb.partition
        if language == Language.PYTHON:
        # Sink nodes don't have a partition function, so we only deserialize
        # partition_bytes when it's not None or empty.
        if language == Language.PYTHON and partition_bytes:
            self.partition = partition.load_partition(partition_bytes)

57  streaming/python/runtime/serialization.py  Normal file

@ -0,0 +1,57 @@
from abc import ABC, abstractmethod
import pickle
import msgpack

from ray.streaming import message

_RECORD_TYPE_ID = 0
_KEY_RECORD_TYPE_ID = 1
_CROSS_LANG_TYPE_ID = b"0"
_JAVA_TYPE_ID = b"1"
_PYTHON_TYPE_ID = b"2"


class Serializer(ABC):
    @abstractmethod
    def serialize(self, obj):
        pass

    @abstractmethod
    def deserialize(self, serialized_bytes):
        pass


class PythonSerializer(Serializer):
    def serialize(self, obj):
        return pickle.dumps(obj)

    def deserialize(self, serialized_bytes):
        return pickle.loads(serialized_bytes)


class CrossLangSerializer(Serializer):
    """Serialize stream elements between java/python."""

    def serialize(self, obj):
        if type(obj) is message.Record:
            fields = [_RECORD_TYPE_ID, obj.stream, obj.value]
        elif type(obj) is message.KeyRecord:
            fields = [_KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value]
        else:
            raise Exception("Unsupported value {}".format(obj))
        return msgpack.packb(fields, use_bin_type=True)

    def deserialize(self, data):
        fields = msgpack.unpackb(data, raw=False)
        if fields[0] == _RECORD_TYPE_ID:
            stream, value = fields[1:]
            record = message.Record(value)
            record.stream = stream
            return record
        elif fields[0] == _KEY_RECORD_TYPE_ID:
            stream, key, value = fields[1:]
            key_record = message.KeyRecord(key, value)
            key_record.stream = stream
            return key_record
        else:
            raise Exception("Unsupported type id {}, type {}".format(
                fields[0], type(fields[0])))
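
The msgpack framing above is the payload that crosses the language boundary. A quick round-trip sketch, mirroring the test added below:

    from ray.streaming.message import Record
    from ray.streaming.runtime.serialization import CrossLangSerializer

    serializer = CrossLangSerializer()
    record = Record("value")
    record.stream = "stream1"
    assert serializer.deserialize(serializer.serialize(record)) == record
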
@ -1,11 +1,13 @@
import logging
import pickle
import threading
from abc import ABC, abstractmethod

from ray.streaming.collector import OutputCollector
from ray.streaming.config import Config
from ray.streaming.context import RuntimeContextImpl
from ray.streaming.runtime import serialization
from ray.streaming.runtime.serialization import \
    PythonSerializer, CrossLangSerializer
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader

logger = logging.getLogger(__name__)

@ -38,36 +40,40 @@ class StreamTask(ABC):
        # writers
        collectors = []
        for edge in execution_node.output_edges:
            output_actor_ids = {}
            output_actors_map = {}
            task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
                edge.target_node_id)
            for target_task_id, target_actor in task_id2_worker.items():
                channel_name = ChannelID.gen_id(self.task_id, target_task_id,
                                                execution_graph.build_time())
                output_actor_ids[channel_name] = target_actor
            if len(output_actor_ids) > 0:
                channel_ids = list(output_actor_ids.keys())
                to_actor_ids = list(output_actor_ids.values())
                writer = DataWriter(channel_ids, to_actor_ids, channel_conf)
                logger.info("Create DataWriter succeed.")
                output_actors_map[channel_name] = target_actor
            if len(output_actors_map) > 0:
                channel_ids = list(output_actors_map.keys())
                target_actors = list(output_actors_map.values())
                logger.info(
                    "Create DataWriter channel_ids {}, target_actors {}."
                    .format(channel_ids, target_actors))
                writer = DataWriter(channel_ids, target_actors, channel_conf)
                self.writers[edge] = writer
                collectors.append(
                    OutputCollector(channel_ids, writer, edge.partition))
                    OutputCollector(writer, channel_ids, target_actors,
                                    edge.partition))

        # readers
        input_actor_ids = {}
        input_actor_map = {}
        for edge in execution_node.input_edges:
            task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
                edge.src_node_id)
            for src_task_id, src_actor in task_id2_worker.items():
                channel_name = ChannelID.gen_id(src_task_id, self.task_id,
                                                execution_graph.build_time())
                input_actor_ids[channel_name] = src_actor
        if len(input_actor_ids) > 0:
            channel_ids = list(input_actor_ids.keys())
            from_actor_ids = list(input_actor_ids.values())
            logger.info("Create DataReader, channels {}.".format(channel_ids))
            self.reader = DataReader(channel_ids, from_actor_ids, channel_conf)
                input_actor_map[channel_name] = src_actor
        if len(input_actor_map) > 0:
            channel_ids = list(input_actor_map.keys())
            from_actors = list(input_actor_map.values())
            logger.info("Create DataReader, channels {}, input_actors {}."
                        .format(channel_ids, from_actors))
            self.reader = DataReader(channel_ids, from_actors, channel_conf)

        def exit_handler():
            # Make DataReader stop read data when MockQueue destructor

@ -111,6 +117,8 @@ class InputStreamTask(StreamTask):
        self.read_timeout_millis = \
            int(worker.config.get(Config.READ_TIMEOUT_MS,
                                  Config.DEFAULT_READ_TIMEOUT_MS))
        self.python_serializer = PythonSerializer()
        self.cross_lang_serializer = CrossLangSerializer()

    def init(self):
        pass

@ -120,7 +128,11 @@ class InputStreamTask(StreamTask):
            item = self.reader.read(self.read_timeout_millis)
            if item is not None:
                msg_data = item.body()
                msg = pickle.loads(msg_data)
                type_id = msg_data[:1]
                if type_id == serialization._PYTHON_TYPE_ID:
                    msg = self.python_serializer.deserialize(msg_data[1:])
                else:
                    msg = self.cross_lang_serializer.deserialize(msg_data[1:])
                self.processor.process(msg)
        self.stopped = True

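A sketch of the one-byte framing the loop above expects: the writer side is assumed to prepend a type-id byte, and the reader strips it before dispatching to the matching serializer (constants from serialization.py above):

    from ray.streaming.runtime import serialization
    from ray.streaming.runtime.serialization import PythonSerializer

    payload = serialization._PYTHON_TYPE_ID \
        + PythonSerializer().serialize("hello")
    assert payload[:1] == serialization._PYTHON_TYPE_ID
    assert PythonSerializer().deserialize(payload[1:]) == "hello"
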
@ -147,13 +147,17 @@ class ChannelCreationParametersBuilder:
    wrap initial parameters needed by a streaming queue
    """
    _java_reader_async_function_descriptor = JavaFunctionDescriptor(
        "io.ray.streaming.runtime.worker", "onReaderMessage", "([B)V")
        "io.ray.streaming.runtime.worker.JobWorker", "onReaderMessage",
        "([B)V")
    _java_reader_sync_function_descriptor = JavaFunctionDescriptor(
        "io.ray.streaming.runtime.worker", "onReaderMessageSync", "([B)[B")
        "io.ray.streaming.runtime.worker.JobWorker", "onReaderMessageSync",
        "([B)[B")
    _java_writer_async_function_descriptor = JavaFunctionDescriptor(
        "io.ray.streaming.runtime.worker", "onWriterMessage", "([B)V")
        "io.ray.streaming.runtime.worker.JobWorker", "onWriterMessage",
        "([B)V")
    _java_writer_sync_function_descriptor = JavaFunctionDescriptor(
        "io.ray.streaming.runtime.worker", "onWriterMessageSync", "([B)[B")
        "io.ray.streaming.runtime.worker.JobWorker", "onWriterMessageSync",
        "([B)[B")
    _python_reader_async_function_descriptor = PythonFunctionDescriptor(
        "ray.streaming.runtime.worker", "on_reader_message", "JobWorker")
    _python_reader_sync_function_descriptor = PythonFunctionDescriptor(

@ -10,6 +10,9 @@ from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask

logger = logging.getLogger(__name__)

# special flag to indicate that this actor is not ready yet
_NOT_READY_FLAG_ = b" " * 4


@ray.remote
class JobWorker(object):

|
@ -66,23 +69,31 @@ class JobWorker(object):
|
|||
type(self.stream_processor))
|
||||
|
||||
def on_reader_message(self, buffer: bytes):
|
||||
"""used in direct call mode"""
|
||||
"""Called by upstream queue writer to send data message to downstream
|
||||
queue reader.
|
||||
"""
|
||||
self.reader_client.on_reader_message(buffer)
|
||||
|
||||
def on_reader_message_sync(self, buffer: bytes):
|
||||
"""used in direct call mode"""
|
||||
"""Called by upstream queue writer to send control message to downstream
|
||||
downstream queue reader.
|
||||
"""
|
||||
if self.reader_client is None:
|
||||
return b" " * 4 # special flag to indicate this actor not ready
|
||||
return _NOT_READY_FLAG_
|
||||
result = self.reader_client.on_reader_message_sync(buffer)
|
||||
return result.to_pybytes()
|
||||
|
||||
def on_writer_message(self, buffer: bytes):
|
||||
"""used in direct call mode"""
|
||||
"""Called by downstream queue reader to send notify message to
|
||||
upstream queue writer.
|
||||
"""
|
||||
self.writer_client.on_writer_message(buffer)
|
||||
|
||||
def on_writer_message_sync(self, buffer: bytes):
|
||||
"""used in direct call mode"""
|
||||
"""Called by downstream queue reader to send control message to
|
||||
upstream queue writer.
|
||||
"""
|
||||
if self.writer_client is None:
|
||||
return b" " * 4 # special flag to indicate this actor not ready
|
||||
return _NOT_READY_FLAG_
|
||||
result = self.writer_client.on_writer_message_sync(buffer)
|
||||
return result.to_pybytes()
|
||||
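
A hedged sketch of how a caller might honor _NOT_READY_FLAG_; only the 4-byte flag comes from the code above, while worker_actor, buf, and the retry policy are illustrative:

    import time

    result = ray.get(worker_actor.on_reader_message_sync.remote(buf))
    while result == _NOT_READY_FLAG_:
        time.sleep(0.1)  # peer actor not initialized yet; retry
        result = ray.get(worker_actor.on_reader_message_sync.remote(buf))
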

@ -14,9 +14,9 @@ class MapFunc(function.MapFunction):


def test_load_function():
    # function_bytes, module_name, class_name, function_name,
    # function_bytes, module_name, function_name/class_name,
    # function_interface
    descriptor_func_bytes = gateway_client.serialize(
        [None, __name__, MapFunc.__name__, None, "MapFunction"])
        [None, __name__, MapFunc.__name__, "MapFunction"])
    func = function.load_function(descriptor_func_bytes)
    assert type(func) is MapFunc

70  streaming/python/tests/test_hybrid_stream.py  Normal file

@ -0,0 +1,70 @@
import json
import ray
from ray.streaming import StreamingContext
import subprocess
import os


def map_func1(x):
    print("HybridStreamTest map_func1", x)
    return str(x)


def filter_func1(x):
    print("HybridStreamTest filter_func1", x)
    return "b" not in x


def sink_func1(x):
    print("HybridStreamTest sink_func1 value:", x)


def test_hybrid_stream():
    subprocess.check_call(
        ["bazel", "build", "//streaming/java:all_streaming_tests_deploy.jar"])
    current_dir = os.path.abspath(os.path.dirname(__file__))
    jar_path = os.path.join(
        current_dir,
        "../../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar")
    jar_path = os.path.abspath(jar_path)
    print("jar_path", jar_path)
    java_worker_options = json.dumps(["-classpath", jar_path])
    print("java_worker_options", java_worker_options)
    assert not ray.is_initialized()
    ray.init(
        load_code_from_local=True,
        include_java=True,
        java_worker_options=java_worker_options,
        _internal_config=json.dumps({
            "num_workers_per_process_java": 1
        }))

    sink_file = "/tmp/ray_streaming_test_hybrid_stream.txt"
    if os.path.exists(sink_file):
        os.remove(sink_file)

    def sink_func(x):
        print("HybridStreamTest", x)
        with open(sink_file, "a") as f:
            f.write(str(x))

    ctx = StreamingContext.Builder().build()
    ctx.from_values("a", "b", "c") \
        .as_java_stream() \
        .map("io.ray.streaming.runtime.demo.HybridStreamTest$Mapper1") \
        .filter("io.ray.streaming.runtime.demo.HybridStreamTest$Filter1") \
        .as_python_stream() \
        .sink(sink_func)
    ctx.submit("HybridStreamTest")
    import time
    time.sleep(3)
    ray.shutdown()
    with open(sink_file, "r") as f:
        result = f.read()
        assert "a" in result
        assert "b" not in result
        assert "c" in result


if __name__ == "__main__":
    test_hybrid_stream()
13  streaming/python/tests/test_serialization.py  Normal file

@ -0,0 +1,13 @@
from ray.streaming.runtime.serialization import CrossLangSerializer
from ray.streaming.message import Record, KeyRecord


def test_serialize():
    serializer = CrossLangSerializer()
    record = Record("value")
    record.stream = "stream1"
    key_record = KeyRecord("key", "value")
    key_record.stream = "stream2"
    assert record == serializer.deserialize(serializer.serialize(record))
    assert key_record == serializer.\
        deserialize(serializer.serialize(key_record))
31  streaming/python/tests/test_stream.py  Normal file

@ -0,0 +1,31 @@
import ray
from ray.streaming import StreamingContext


def test_data_stream():
    ray.init(load_code_from_local=True, include_java=True)
    ctx = StreamingContext.Builder().build()
    stream = ctx.from_values(1, 2, 3)
    java_stream = stream.as_java_stream()
    python_stream = java_stream.as_python_stream()
    assert stream.get_id() == java_stream.get_id()
    assert stream.get_id() == python_stream.get_id()
    python_stream.set_parallelism(10)
    assert stream.get_parallelism() == java_stream.get_parallelism()
    assert stream.get_parallelism() == python_stream.get_parallelism()
    ray.shutdown()


def test_key_data_stream():
    ray.init(load_code_from_local=True, include_java=True)
    ctx = StreamingContext.Builder().build()
    key_stream = ctx.from_values(
        "a", "b", "c").map(lambda x: (x, 1)).key_by(lambda x: x[0])
    java_stream = key_stream.as_java_stream()
    python_stream = java_stream.as_python_stream()
    assert key_stream.get_id() == java_stream.get_id()
    assert key_stream.get_id() == python_stream.get_id()
    python_stream.set_parallelism(10)
    assert key_stream.get_parallelism() == java_stream.get_parallelism()
    assert key_stream.get_parallelism() == python_stream.get_parallelism()
    ray.shutdown()

@ -32,7 +32,9 @@ def test_simple_word_count():

    def sink_func(x):
        with open(sink_file, "a") as f:
            f.write("{}:{},".format(x[0], x[1]))
            line = "{}:{},".format(x[0], x[1])
            print("sink_func", line)
            f.write(line)

    ctx.from_values("a", "b", "c") \
        .set_parallelism(1) \

@ -26,6 +26,13 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(
  return reinterpret_cast<jlong>(reader_client);
}

JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative(
    JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
  auto *writer_client = reinterpret_cast<WriterClient *>(ptr);
  writer_client->OnWriterMessage(JByteArrayToBuffer(env, bytes));
}

JNIEXPORT jbyteArray JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative(
    JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {