[Streaming] Streaming Cross-Lang API (#7464)

This commit is contained in:
chaokunyang 2020-04-29 13:42:08 +08:00 committed by GitHub
parent 101255f782
commit 91f630f709
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
72 changed files with 1612 additions and 408 deletions

View file

@ -542,6 +542,7 @@ def init(address=None,
raylet_socket_name=None,
temp_dir=None,
load_code_from_local=False,
java_worker_options=None,
use_pickle=True,
_internal_config=None,
lru_evict=False):
@ -651,6 +652,7 @@ def init(address=None,
conventional location, e.g., "/tmp/ray".
load_code_from_local: Whether code should be loaded from a local
module or from the GCS.
java_worker_options: Overwrite the options to start Java workers.
use_pickle: Deprecated.
_internal_config (str): JSON configuration for overriding
RayConfig defaults. For testing purposes ONLY.
@ -758,6 +760,7 @@ def init(address=None,
raylet_socket_name=raylet_socket_name,
temp_dir=temp_dir,
load_code_from_local=load_code_from_local,
java_worker_options=java_worker_options,
_internal_config=_internal_config,
)
# Start the Ray processes. We set shutdown_at_exit=False because we
@ -808,6 +811,9 @@ def init(address=None,
if raylet_socket_name is not None:
raise ValueError("When connecting to an existing cluster, "
"raylet_socket_name must not be provided.")
if java_worker_options is not None:
raise ValueError("When connecting to an existing cluster, "
"java_worker_options must not be provided.")
if _internal_config is not None and len(_internal_config) != 0:
raise ValueError("When connecting to an existing cluster, "
"_internal_config must not be provided.")

View file

@ -39,6 +39,7 @@ define_java_module(
":io_ray_ray_streaming-state",
":io_ray_ray_streaming-api",
"@ray_streaming_maven//:com_google_guava_guava",
"@ray_streaming_maven//:org_apache_commons_commons_lang3",
"@ray_streaming_maven//:org_slf4j_slf4j_api",
"@ray_streaming_maven//:org_slf4j_slf4j_log4j12",
"@ray_streaming_maven//:org_testng_testng",
@ -46,7 +47,12 @@ define_java_module(
visibility = ["//visibility:public"],
deps = [
":io_ray_ray_streaming-state",
"//java:io_ray_ray_api",
"//java:io_ray_ray_runtime",
"@ray_streaming_maven//:com_google_code_findbugs_jsr305",
"@ray_streaming_maven//:com_google_code_gson_gson",
"@ray_streaming_maven//:com_google_guava_guava",
"@ray_streaming_maven//:org_apache_commons_commons_lang3",
"@ray_streaming_maven//:org_slf4j_slf4j_api",
"@ray_streaming_maven//:org_slf4j_slf4j_log4j12",
],
@ -129,8 +135,9 @@ define_java_module(
":io_ray_ray_streaming-api",
":io_ray_ray_streaming-runtime",
"@ray_streaming_maven//:com_google_guava_guava",
"@ray_streaming_maven//:com_google_code_findbugs_jsr305",
"@ray_streaming_maven//:org_apache_commons_commons_lang3",
"@ray_streaming_maven//:de_ruedigermoeller_fst",
"@ray_streaming_maven//:org_msgpack_msgpack_core",
"@ray_streaming_maven//:org_aeonbits_owner_owner",
"@ray_streaming_maven//:org_slf4j_slf4j_api",
"@ray_streaming_maven//:org_slf4j_slf4j_log4j12",
@ -146,10 +153,12 @@ define_java_module(
"//java:io_ray_ray_api",
"//java:io_ray_ray_runtime",
"@ray_streaming_maven//:com_github_davidmoten_flatbuffers_java",
"@ray_streaming_maven//:com_google_code_findbugs_jsr305",
"@ray_streaming_maven//:com_google_guava_guava",
"@ray_streaming_maven//:com_google_protobuf_protobuf_java",
"@ray_streaming_maven//:de_ruedigermoeller_fst",
"@ray_streaming_maven//:org_aeonbits_owner_owner",
"@ray_streaming_maven//:org_apache_commons_commons_lang3",
"@ray_streaming_maven//:org_msgpack_msgpack_core",
"@ray_streaming_maven//:org_slf4j_slf4j_api",
"@ray_streaming_maven//:org_slf4j_slf4j_log4j12",

View file

@ -6,8 +6,11 @@ def gen_streaming_java_deps():
artifacts = [
"com.beust:jcommander:1.72",
"com.google.guava:guava:27.0.1-jre",
"com.google.code.findbugs:jsr305:3.0.2",
"com.google.code.gson:gson:2.8.5",
"com.github.davidmoten:flatbuffers-java:1.9.0.1",
"com.google.protobuf:protobuf-java:3.8.0",
"org.apache.commons:commons-lang3:3.4",
"de.ruedigermoeller:fst:2.57",
"org.aeonbits.owner:owner:1.0.10",
"org.slf4j:slf4j-api:1.7.12",
@ -19,10 +22,9 @@ def gen_streaming_java_deps():
"org.apache.commons:commons-lang3:3.3.2",
"org.msgpack:msgpack-core:0.8.20",
"org.testng:testng:6.9.10",
"org.mockito:mockito-all:1.10.19",
"org.powermock:powermock-module-testng:1.6.6",
"org.powermock:powermock-api-mockito:1.6.6",
"org.projectlombok:lombok:1.16.20",
"org.mockito:mockito-all:1.10.19",
"org.powermock:powermock-module-testng:1.6.6",
"org.powermock:powermock-api-mockito:1.6.6",
],
repositories = [
"https://repo1.maven.org/maven2/",

View file

@ -22,16 +22,36 @@
<artifactId>ray-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.ray</groupId>
<artifactId>ray-runtime</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.ray</groupId>
<artifactId>streaming-state</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
<version>3.0.2</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.8.5</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>27.0.1-jre</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>

View file

@ -22,6 +22,11 @@
<artifactId>ray-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.ray</groupId>
<artifactId>ray-runtime</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.ray</groupId>
<artifactId>streaming-state</artifactId>

View file

@ -0,0 +1,129 @@
package io.ray.streaming.api.context;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.gson.Gson;
import io.ray.api.Ray;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.util.NetworkUtil;
import java.io.File;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Starts and stops the Ray runtime used by a streaming job launched from a Java driver.
 *
 * <p>Non-cross-language jobs simply call {@code Ray.init()} (in single-process or cluster
 * run mode). Cross-language jobs launch a head node through the {@code ray start} command
 * line with {@code --include-java} and then connect this process to it.
 */
class ClusterStarter {
  private static final Logger LOG = LoggerFactory.getLogger(ClusterStarter.class);
  private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/plasma_store_socket";
  private static final String RAYLET_SOCKET_NAME = "/tmp/ray/raylet_socket";

  /**
   * Start (or connect to) a Ray runtime for the current job.
   *
   * @param isCrossLanguage whether the job graph mixes languages; if true a cluster is
   *     started with the {@code ray start} command before connecting to it.
   * @param isLocal whether to run in {@code SINGLE_PROCESS} mode instead of {@code CLUSTER}.
   * @throws IllegalArgumentException if a Ray runtime is already initialized.
   * @throws RuntimeException if the {@code ray start} command fails or times out.
   */
  static synchronized void startCluster(boolean isCrossLanguage, boolean isLocal) {
    Preconditions.checkArgument(Ray.internal() == null);
    RayConfig.reset();
    if (!isLocal) {
      System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
      System.setProperty("ray.run-mode", "CLUSTER");
    } else {
      System.clearProperty("ray.raylet.config.num_workers_per_process_java");
      System.setProperty("ray.run-mode", "SINGLE_PROCESS");
    }

    if (!isCrossLanguage) {
      Ray.init();
      return;
    }

    // Delete existing socket files left over from a previous run.
    for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
      File file = new File(socket);
      if (file.exists()) {
        LOG.info("Delete existing socket file {}", file);
        if (!file.delete()) {
          // Best effort: `ray start` will report its own error if the socket is unusable.
          LOG.warn("Failed to delete existing socket file {}", file);
        }
      }
    }

    String nodeManagerPort = String.valueOf(NetworkUtil.getUnusedPort());

    // The jars in the `ray` wheel don't contain test classes, so we add test classes
    // explicitly. Since mvn test classes contain `test` in their path and bazel test classes
    // are located in a jar with `test` in its name, we can filter the classpath on `test`
    // to pick out the test classes.
    String classpath = Stream.of(System.getProperty("java.class.path").split(":"))
        .filter(s -> !s.contains(" ") && s.contains("test"))
        .collect(Collectors.joining(":"));
    String workerOptions = new Gson().toJson(ImmutableList.of("-classpath", classpath));
    Map<String, String> config = new HashMap<>(RayConfig.create().rayletConfigParameters);
    config.put("num_workers_per_process_java", "1");

    // Start ray cluster.
    List<String> startCommand = ImmutableList.of(
        "ray",
        "start",
        "--head",
        "--redis-port=6379",
        String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
        String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
        String.format("--node-manager-port=%s", nodeManagerPort),
        "--load-code-from-local",
        "--include-java",
        "--java-worker-options=" + workerOptions,
        "--internal-config=" + new Gson().toJson(config)
    );
    if (!executeCommand(startCommand, 10)) {
      throw new RuntimeException("Couldn't start ray cluster.");
    }

    // Connect to the cluster that was just started.
    System.setProperty("ray.redis.address", "127.0.0.1:6379");
    System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
    System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
    System.setProperty("ray.raylet.node-manager-port", nodeManagerPort);
    Ray.init();
  }

  /**
   * Shut down the Ray runtime of this process and clear the system properties set by
   * {@link #startCluster(boolean, boolean)}.
   *
   * @param isCrossLanguage whether the cluster was started through {@code ray start}; if
   *     true, {@code ray stop} is executed as well.
   * @throws RuntimeException if the {@code ray stop} command fails or times out.
   */
  public static synchronized void stopCluster(boolean isCrossLanguage) {
    // Disconnect from the cluster.
    Ray.shutdown();
    System.clearProperty("ray.redis.address");
    System.clearProperty("ray.object-store.socket-name");
    System.clearProperty("ray.raylet.socket-name");
    System.clearProperty("ray.raylet.node-manager-port");
    System.clearProperty("ray.raylet.config.num_workers_per_process_java");
    System.clearProperty("ray.run-mode");

    if (isCrossLanguage) {
      // Stop ray cluster.
      final List<String> stopCommand = ImmutableList.of(
          "ray",
          "stop"
      );
      if (!executeCommand(stopCommand, 10)) {
        throw new RuntimeException("Couldn't stop ray cluster");
      }
    }
  }

  /**
   * Execute an external command, inheriting this process's stdout and stderr.
   *
   * @param command the command and its arguments.
   * @param waitTimeoutSeconds how long to wait for the command to finish.
   * @return Whether the command succeeded, i.e. exited with status 0 within the timeout.
   * @throws RuntimeException if the command could not be started or the wait was interrupted.
   */
  private static boolean executeCommand(List<String> command, int waitTimeoutSeconds) {
    LOG.info("Executing command: {}", String.join(" ", command));
    try {
      ProcessBuilder processBuilder = new ProcessBuilder(command)
          .redirectOutput(ProcessBuilder.Redirect.INHERIT)
          .redirectError(ProcessBuilder.Redirect.INHERIT);
      Process process = processBuilder.start();
      boolean exited = process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS);
      if (!exited) {
        // destroyForcibly() does not wait for termination, so calling exitValue() right
        // after it may throw IllegalThreadStateException. A timed-out command is a failure.
        LOG.warn("Command did not finish within {} seconds, killing it: {}",
            waitTimeoutSeconds, String.join(" ", command));
        process.destroyForcibly();
        return false;
      }
      return process.exitValue() == 0;
    } catch (InterruptedException e) {
      // Preserve the interrupt status for callers further up the stack.
      Thread.currentThread().interrupt();
      throw new RuntimeException("Error executing command " + String.join(" ", command), e);
    } catch (Exception e) {
      throw new RuntimeException("Error executing command " + String.join(" ", command), e);
    }
  }
}

View file

@ -1,10 +1,12 @@
package io.ray.streaming.api.context;
import com.google.common.base.Preconditions;
import io.ray.api.Ray;
import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.jobgraph.JobGraphBuilder;
import io.ray.streaming.schedule.JobScheduler;
import io.ray.streaming.util.Config;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
@ -13,11 +15,14 @@ import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Encapsulate the context information of a streaming Job.
*/
public class StreamingContext implements Serializable {
private static final Logger LOG = LoggerFactory.getLogger(StreamingContext.class);
private transient AtomicInteger idGenerator;
@ -54,6 +59,20 @@ public class StreamingContext implements Serializable {
this.jobGraph = jobGraphBuilder.build();
jobGraph.printJobGraph();
if (Ray.internal() == null) {
if (Config.MEMORY_CHANNEL.equalsIgnoreCase(jobConfig.get(Config.CHANNEL_TYPE))) {
Preconditions.checkArgument(!jobGraph.isCrossLanguageGraph());
ClusterStarter.startCluster(false, true);
LOG.info("Created local cluster for job {}.", jobName);
} else {
ClusterStarter.startCluster(jobGraph.isCrossLanguageGraph(), false);
LOG.info("Created multi process cluster for job {}.", jobName);
}
Runtime.getRuntime().addShutdownHook(new Thread(StreamingContext.this::stop));
} else {
LOG.info("Reuse existing cluster.");
}
ServiceLoader<JobScheduler> serviceLoader = ServiceLoader.load(JobScheduler.class);
Iterator<JobScheduler> iterator = serviceLoader.iterator();
Preconditions.checkArgument(iterator.hasNext(),
@ -77,4 +96,10 @@ public class StreamingContext implements Serializable {
public void withConfig(Map<String, String> jobConfig) {
this.jobConfig = jobConfig;
}
/**
 * Stop the Ray cluster used by this job, if a Ray runtime is currently initialized.
 *
 * <p>Does nothing when no job graph has been built yet or when no Ray runtime is
 * running, so it is safe to call more than once.
 */
public void stop() {
  // jobGraph may still be unset if stop() is called before the job graph is built;
  // in that case there is nothing of ours to stop and dereferencing it would throw.
  if (jobGraph == null) {
    return;
  }
  if (Ray.internal() != null) {
    ClusterStarter.stopCluster(jobGraph.isCrossLanguageGraph());
  }
}
}

View file

@ -1,6 +1,7 @@
package io.ray.streaming.api.stream;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.FilterFunction;
import io.ray.streaming.api.function.impl.FlatMapFunction;
@ -15,24 +16,44 @@ import io.ray.streaming.operator.impl.FlatMapOperator;
import io.ray.streaming.operator.impl.KeyByOperator;
import io.ray.streaming.operator.impl.MapOperator;
import io.ray.streaming.operator.impl.SinkOperator;
import io.ray.streaming.python.stream.PythonDataStream;
/**
* Represents a stream of data.
*
* This class defines all the streaming operations.
* <p>This class defines all the streaming operations.
*
* @param <T> Type of data in the stream.
*/
public class DataStream<T> extends Stream<T> {
public class DataStream<T> extends Stream<DataStream<T>, T> {
public DataStream(StreamingContext streamingContext, StreamOperator streamOperator) {
super(streamingContext, streamOperator);
}
public DataStream(DataStream input, StreamOperator streamOperator) {
public DataStream(StreamingContext streamingContext,
StreamOperator streamOperator,
Partition<T> partition) {
super(streamingContext, streamOperator, partition);
}
public <R> DataStream(DataStream<R> input, StreamOperator streamOperator) {
super(input, streamOperator);
}
public <R> DataStream(DataStream<R> input,
StreamOperator streamOperator,
Partition<T> partition) {
super(input, streamOperator, partition);
}
/**
* Create a Java stream that references the passed Python stream.
* Changes in the new stream will be reflected in the referenced stream and vice versa.
*/
public DataStream(PythonDataStream referencedStream) {
super(referencedStream);
}
/**
* Apply a map function to this stream.
*
@ -41,7 +62,7 @@ public class DataStream<T> extends Stream<T> {
* @return A new DataStream.
*/
public <R> DataStream<R> map(MapFunction<T, R> mapFunction) {
return new DataStream<>(this, new MapOperator(mapFunction));
return new DataStream<>(this, new MapOperator<>(mapFunction));
}
/**
@ -52,11 +73,11 @@ public class DataStream<T> extends Stream<T> {
* @return A new DataStream
*/
public <R> DataStream<R> flatMap(FlatMapFunction<T, R> flatMapFunction) {
return new DataStream(this, new FlatMapOperator(flatMapFunction));
return new DataStream<>(this, new FlatMapOperator<>(flatMapFunction));
}
public DataStream<T> filter(FilterFunction<T> filterFunction) {
return new DataStream<T>(this, new FilterOperator(filterFunction));
return new DataStream<>(this, new FilterOperator<>(filterFunction));
}
/**
@ -66,7 +87,7 @@ public class DataStream<T> extends Stream<T> {
* @return A new UnionStream.
*/
public UnionStream<T> union(DataStream<T> other) {
return new UnionStream(this, null, other);
return new UnionStream<>(this, null, other);
}
/**
@ -93,7 +114,7 @@ public class DataStream<T> extends Stream<T> {
* @return A new StreamSink.
*/
public DataStreamSink<T> sink(SinkFunction<T> sinkFunction) {
return new DataStreamSink<>(this, new SinkOperator(sinkFunction));
return new DataStreamSink<>(this, new SinkOperator<>(sinkFunction));
}
/**
@ -104,7 +125,8 @@ public class DataStream<T> extends Stream<T> {
* @return A new KeyDataStream.
*/
public <K> KeyDataStream<K, T> keyBy(KeyFunction<T, K> keyFunction) {
return new KeyDataStream<>(this, new KeyByOperator(keyFunction));
checkPartitionCall();
return new KeyDataStream<>(this, new KeyByOperator<>(keyFunction));
}
/**
@ -113,8 +135,8 @@ public class DataStream<T> extends Stream<T> {
* @return This stream.
*/
public DataStream<T> broadcast() {
this.partition = new BroadcastPartition<>();
return this;
checkPartitionCall();
return setPartition(new BroadcastPartition<>());
}
/**
@ -124,19 +146,32 @@ public class DataStream<T> extends Stream<T> {
* @return This stream.
*/
public DataStream<T> partitionBy(Partition<T> partition) {
this.partition = partition;
return this;
checkPartitionCall();
return setPartition(partition);
}
/**
* Set parallelism to current transformation.
*
* @param parallelism The parallelism to set.
* @return This stream.
* If parent stream is a python stream, we can't call partition related methods
* in the java stream.
*/
public DataStream<T> setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
private void checkPartitionCall() {
if (getInputStream() != null && getInputStream().getLanguage() == Language.PYTHON) {
throw new RuntimeException("Partition related methods can't be called on a " +
"java stream if parent stream is a python stream.");
}
}
/**
* Convert this stream as a python stream.
* The converted stream and this stream are the same logical stream, which has same stream id.
* Changes in converted stream will be reflected in this stream and vice versa.
*/
public PythonDataStream asPythonStream() {
return new PythonDataStream(this);
}
@Override
public Language getLanguage() {
return Language.JAVA;
}
}

View file

@ -1,5 +1,6 @@
package io.ray.streaming.api.stream;
import io.ray.streaming.api.Language;
import io.ray.streaming.operator.impl.SinkOperator;
/**
@ -9,13 +10,13 @@ import io.ray.streaming.operator.impl.SinkOperator;
*/
public class DataStreamSink<T> extends StreamSink<T> {
public DataStreamSink(DataStream<T> input, SinkOperator sinkOperator) {
public DataStreamSink(DataStream input, SinkOperator sinkOperator) {
super(input, sinkOperator);
this.streamingContext.addSink(this);
getStreamingContext().addSink(this);
}
public DataStreamSink<T> setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
@Override
public Language getLanguage() {
return Language.JAVA;
}
}

View file

@ -14,27 +14,26 @@ import java.util.Collection;
*/
public class DataStreamSource<T> extends DataStream<T> implements StreamSource<T> {
public DataStreamSource(StreamingContext streamingContext, SourceFunction<T> sourceFunction) {
super(streamingContext, new SourceOperator<>(sourceFunction));
super.partition = new RoundRobinPartition<>();
private DataStreamSource(StreamingContext streamingContext, SourceFunction<T> sourceFunction) {
super(streamingContext, new SourceOperator<>(sourceFunction), new RoundRobinPartition<>());
}
public static <T> DataStreamSource<T> fromSource(
StreamingContext context, SourceFunction<T> sourceFunction) {
return new DataStreamSource<>(context, sourceFunction);
}
/**
* Build a DataStreamSource source from a collection.
*
* @param context Stream context.
* @param values A collection of values.
* @param <T> The type of source data.
* @param values A collection of values.
* @param <T> The type of source data.
* @return A DataStreamSource.
*/
public static <T> DataStreamSource<T> buildSource(
public static <T> DataStreamSource<T> fromCollection(
StreamingContext context, Collection<T> values) {
return new DataStreamSource(context, new CollectionSourceFunction(values));
return new DataStreamSource<>(context, new CollectionSourceFunction<>(values));
}
@Override
public DataStreamSource<T> setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
}
}

View file

@ -2,9 +2,12 @@ package io.ray.streaming.api.stream;
import io.ray.streaming.api.function.impl.AggregateFunction;
import io.ray.streaming.api.function.impl.ReduceFunction;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.partition.impl.KeyPartition;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.operator.impl.ReduceOperator;
import io.ray.streaming.python.stream.PythonDataStream;
import io.ray.streaming.python.stream.PythonKeyDataStream;
/**
* Represents a DataStream returned by a key-by operation.
@ -12,11 +15,19 @@ import io.ray.streaming.operator.impl.ReduceOperator;
* @param <K> Type of the key.
* @param <T> Type of the data.
*/
@SuppressWarnings("unchecked")
public class KeyDataStream<K, T> extends DataStream<T> {
public KeyDataStream(DataStream<T> input, StreamOperator streamOperator) {
super(input, streamOperator);
this.partition = new KeyPartition();
super(input, streamOperator, (Partition<T>) new KeyPartition<K, T>());
}
/**
* Create a Java stream that references the passed Python stream.
* Changes in the new stream will be reflected in the referenced stream and vice versa.
*/
public KeyDataStream(PythonDataStream referencedStream) {
super(referencedStream);
}
/**
@ -41,8 +52,13 @@ public class KeyDataStream<K, T> extends DataStream<T> {
return new DataStream<>(this, null);
}
public KeyDataStream<K, T> setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
/**
* Convert this stream as a python stream.
* The converted stream and this stream are the same logical stream, which has same stream id.
* Changes in converted stream will be reflected in this stream and vice versa.
*/
public PythonKeyDataStream asPythonStream() {
return new PythonKeyDataStream(this);
}
}

View file

@ -1,58 +1,99 @@
package io.ray.streaming.api.stream;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
import io.ray.streaming.operator.Operator;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.python.PythonOperator;
import io.ray.streaming.python.PythonPartition;
import io.ray.streaming.python.stream.PythonStream;
import java.io.Serializable;
/**
* Abstract base class of all stream types.
*
* @param <S> Type of stream class
* @param <T> Type of the data in the stream.
*/
public abstract class Stream<T> implements Serializable {
protected int id;
protected int parallelism = 1;
protected StreamOperator operator;
protected Stream<T> inputStream;
protected StreamingContext streamingContext;
protected Partition<T> partition;
public abstract class Stream<S extends Stream<S, T>, T>
implements Serializable {
private final int id;
private final StreamingContext streamingContext;
private final Stream inputStream;
private final StreamOperator operator;
private int parallelism = 1;
private Partition<T> partition;
private Stream originalStream;
@SuppressWarnings("unchecked")
public Stream(StreamingContext streamingContext, StreamOperator streamOperator) {
this(streamingContext, null, streamOperator,
selectPartition(streamOperator));
}
public Stream(StreamingContext streamingContext,
StreamOperator streamOperator,
Partition<T> partition) {
this(streamingContext, null, streamOperator, partition);
}
public Stream(Stream inputStream, StreamOperator streamOperator) {
this(inputStream.getStreamingContext(), inputStream, streamOperator,
selectPartition(streamOperator));
}
public Stream(Stream inputStream, StreamOperator streamOperator, Partition<T> partition) {
this(inputStream.getStreamingContext(), inputStream, streamOperator, partition);
}
protected Stream(StreamingContext streamingContext,
Stream inputStream,
StreamOperator streamOperator,
Partition<T> partition) {
this.streamingContext = streamingContext;
this.inputStream = inputStream;
this.operator = streamOperator;
this.partition = partition;
this.id = streamingContext.generateId();
if (streamOperator instanceof PythonOperator) {
this.partition = PythonPartition.RoundRobinPartition;
} else {
this.partition = new RoundRobinPartition<>();
if (inputStream != null) {
this.parallelism = inputStream.getParallelism();
}
}
public Stream(Stream<T> inputStream, StreamOperator streamOperator) {
this.inputStream = inputStream;
this.parallelism = inputStream.getParallelism();
this.streamingContext = this.inputStream.getStreamingContext();
this.operator = streamOperator;
this.id = streamingContext.generateId();
this.partition = selectPartition();
/**
* Create a proxy stream of the original stream.
* Changes in the new stream will be reflected in the original stream and vice versa.
*/
protected Stream(Stream originalStream) {
this.originalStream = originalStream;
this.id = originalStream.getId();
this.streamingContext = originalStream.getStreamingContext();
this.inputStream = originalStream.getInputStream();
this.operator = originalStream.getOperator();
}
@SuppressWarnings("unchecked")
private Partition<T> selectPartition() {
if (inputStream instanceof PythonStream) {
return PythonPartition.RoundRobinPartition;
} else {
return new RoundRobinPartition<>();
private static <T> Partition<T> selectPartition(Operator operator) {
switch (operator.getLanguage()) {
case PYTHON:
return (Partition<T>) PythonPartition.RoundRobinPartition;
case JAVA:
return new RoundRobinPartition<>();
default:
throw new UnsupportedOperationException(
"Unsupported language " + operator.getLanguage());
}
}
public Stream<T> getInputStream() {
public int getId() {
return id;
}
public StreamingContext getStreamingContext() {
return streamingContext;
}
public Stream getInputStream() {
return inputStream;
}
@ -60,32 +101,47 @@ public abstract class Stream<T> implements Serializable {
return operator;
}
public void setOperator(StreamOperator operator) {
this.operator = operator;
}
public StreamingContext getStreamingContext() {
return streamingContext;
@SuppressWarnings("unchecked")
private S self() {
return (S) this;
}
public int getParallelism() {
return parallelism;
return originalStream != null ? originalStream.getParallelism() : parallelism;
}
public Stream<T> setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
}
public int getId() {
return id;
public S setParallelism(int parallelism) {
if (originalStream != null) {
originalStream.setParallelism(parallelism);
} else {
this.parallelism = parallelism;
}
return self();
}
@SuppressWarnings("unchecked")
public Partition<T> getPartition() {
return partition;
return originalStream != null ? originalStream.getPartition() : partition;
}
public void setPartition(Partition<T> partition) {
this.partition = partition;
@SuppressWarnings("unchecked")
protected S setPartition(Partition<T> partition) {
if (originalStream != null) {
originalStream.setPartition(partition);
} else {
this.partition = partition;
}
return self();
}
public boolean isProxyStream() {
return originalStream != null;
}
public Stream getOriginalStream() {
Preconditions.checkArgument(isProxyStream());
return originalStream;
}
public abstract Language getLanguage();
}

View file

@ -7,8 +7,8 @@ import io.ray.streaming.operator.StreamOperator;
*
* @param <T> Type of the input data of this sink.
*/
public class StreamSink<T> extends Stream<T> {
public StreamSink(Stream<T> inputStream, StreamOperator streamOperator) {
public abstract class StreamSink<T> extends Stream<StreamSink<T>, T> {
public StreamSink(Stream inputStream, StreamOperator streamOperator) {
super(inputStream, streamOperator);
}
}

View file

@ -11,15 +11,15 @@ import java.util.List;
*/
public class UnionStream<T> extends DataStream<T> {
private List<DataStream> unionStreams;
private List<DataStream<T>> unionStreams;
public UnionStream(DataStream input, StreamOperator streamOperator, DataStream<T> other) {
public UnionStream(DataStream<T> input, StreamOperator streamOperator, DataStream<T> other) {
super(input, streamOperator);
this.unionStreams = new ArrayList<>();
this.unionStreams.add(other);
}
public List<DataStream> getUnionStreams() {
public List<DataStream<T>> getUnionStreams() {
return unionStreams;
}
}

View file

@ -1,5 +1,6 @@
package io.ray.streaming.jobgraph;
import io.ray.streaming.api.Language;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
@ -97,4 +98,14 @@ public class JobGraph implements Serializable {
}
}
/**
 * Check whether the vertices of this job graph use more than one language.
 *
 * @return true if at least two vertices report different languages; false otherwise,
 *     including when the graph has no vertices.
 */
public boolean isCrossLanguageGraph() {
  // An empty graph has no first vertex to compare against and is trivially single-language.
  if (jobVertexList.isEmpty()) {
    return false;
  }
  Language language = jobVertexList.get(0).getLanguage();
  for (JobVertex jobVertex : jobVertexList) {
    if (jobVertex.getLanguage() != language) {
      return true;
    }
  }
  return false;
}
}

View file

@ -1,5 +1,6 @@
package io.ray.streaming.jobgraph;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.Stream;
import io.ray.streaming.api.stream.StreamSink;
@ -10,8 +11,11 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class JobGraphBuilder {
private static final Logger LOG = LoggerFactory.getLogger(JobGraphBuilder.class);
private JobGraph jobGraph;
@ -41,12 +45,19 @@ public class JobGraphBuilder {
}
private void processStream(Stream stream) {
while (stream.isProxyStream()) {
// Proxy stream and original stream are the same logical stream, both refer to the
// same data flow transformation. We should skip proxy stream to avoid applying same
// transformation multiple times.
LOG.debug("Skip proxy stream {} of id {}", stream, stream.getId());
stream = stream.getOriginalStream();
}
StreamOperator streamOperator = stream.getOperator();
Preconditions.checkArgument(stream.getLanguage() == streamOperator.getLanguage(),
"Reference stream should be skipped.");
int vertexId = stream.getId();
int parallelism = stream.getParallelism();
StreamOperator streamOperator = stream.getOperator();
JobVertex jobVertex = null;
JobVertex jobVertex;
if (stream instanceof StreamSink) {
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator);
Stream parentStream = stream.getInputStream();

View file

@ -1,6 +1,8 @@
package io.ray.streaming.message;
import java.util.Objects;
public class KeyRecord<K, T> extends Record<T> {
private K key;
@ -17,4 +19,24 @@ public class KeyRecord<K, T> extends Record<T> {
public void setKey(K key) {
this.key = key;
}
/**
 * Two key records are equal when they have the same runtime class, are equal according
 * to the parent class, and carry equal keys.
 */
@Override
public boolean equals(Object o) {
  if (o == this) {
    return true;
  }
  boolean comparable = o != null && o.getClass() == getClass() && super.equals(o);
  if (!comparable) {
    return false;
  }
  KeyRecord<?, ?> other = (KeyRecord<?, ?>) o;
  return Objects.equals(key, other.key);
}
/**
 * Combine the parent hash with the key hash, using the same 31-based formula as
 * {@code Objects.hash(super.hashCode(), key)}.
 */
@Override
public int hashCode() {
  int result = 31 + super.hashCode();
  result = 31 * result + Objects.hashCode(key);
  return result;
}
}

View file

@ -1,64 +0,0 @@
package io.ray.streaming.message;
import com.google.common.collect.Lists;
import java.io.Serializable;
import java.util.List;
/**
 * A serializable batch of records tagged with the task id, batch id and stream name
 * they belong to.
 */
public class Message implements Serializable {
  private int taskId;
  private long batchId;
  private String stream;
  private List<Record> recordList;

  /**
   * Create a message carrying the given list of records.
   */
  public Message(int taskId, long batchId, String stream, List<Record> recordList) {
    this.taskId = taskId;
    this.batchId = batchId;
    this.stream = stream;
    this.recordList = recordList;
  }

  /**
   * Create a message carrying a single record.
   */
  public Message(int taskId, long batchId, String stream, Record record) {
    this.taskId = taskId;
    this.batchId = batchId;
    this.stream = stream;
    this.recordList = Lists.newArrayList(record);
  }

  public int getTaskId() {
    return taskId;
  }

  public void setTaskId(int taskId) {
    this.taskId = taskId;
  }

  public long getBatchId() {
    return batchId;
  }

  public void setBatchId(long batchId) {
    this.batchId = batchId;
  }

  public String getStream() {
    return stream;
  }

  public void setStream(String stream) {
    this.stream = stream;
  }

  public List<Record> getRecordList() {
    return recordList;
  }

  public void setRecordList(List<Record> recordList) {
    this.recordList = recordList;
  }

  /**
   * Get the record at the given position in this message.
   *
   * @param index zero-based position of the record.
   * @return the record at {@code index}.
   * @throws IndexOutOfBoundsException if {@code index} is out of range.
   */
  public Record getRecord(int index) {
    // Previously this ignored `index` and always returned the first record.
    return recordList.get(index);
  }
}

View file

@ -1,6 +1,7 @@
package io.ray.streaming.message;
import java.io.Serializable;
import java.util.Objects;
public class Record<T> implements Serializable {
@ -27,6 +28,24 @@ public class Record<T> implements Serializable {
this.stream = stream;
}
/**
 * Two records are equal when they have the same runtime class, the same stream and
 * equal values.
 */
@Override
public boolean equals(Object o) {
  if (o == this) {
    return true;
  }
  if (o == null || o.getClass() != getClass()) {
    return false;
  }
  Record<?> other = (Record<?>) o;
  if (!Objects.equals(stream, other.stream)) {
    return false;
  }
  return Objects.equals(value, other.value);
}
/**
 * Hash over the stream and the value, using the same 31-based formula as
 * {@code Objects.hash(stream, value)}.
 */
@Override
public int hashCode() {
  int result = 31 + Objects.hashCode(stream);
  result = 31 * result + Objects.hashCode(value);
  return result;
}
@Override
public String toString() {
return value.toString();

View file

@ -1,6 +1,8 @@
package io.ray.streaming.python;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.function.Function;
import org.apache.commons.lang3.StringUtils;
/**
* Represents a user defined python function.
@ -14,9 +16,8 @@ import io.ray.streaming.api.function.Function;
*
* <p>If the python data stream api is invoked from python, `function` will be not null.</p>
* <p>If the python data stream api is invoked from java, `moduleName` and
* `className`/`functionName` will be not null.</p>
* `functionName` will be not null.</p>
* <p>
* TODO serialize to bytes using protobuf
*/
public class PythonFunction implements Function {
public enum FunctionInterface {
@ -38,23 +39,43 @@ public class PythonFunction implements Function {
}
}
private byte[] function;
private String moduleName;
private String className;
private String functionName;
// null if this function is constructed from moduleName/functionName.
private final byte[] function;
// null if this function is constructed from serialized python function.
private final String moduleName;
// null if this function is constructed from serialized python function.
private final String functionName;
/**
* FunctionInterface can be used to validate python function,
* and look up operator class from FunctionInterface.
*/
private String functionInterface;
private PythonFunction(byte[] function,
String moduleName,
String className,
String functionName) {
/**
 * Create a {@link PythonFunction} backed by a serialized streaming python function.
 * A function created this way has no module name and no function name.
 *
 * @param function serialized streaming python function from python driver, must not be null.
 */
public PythonFunction(byte[] function) {
  // checkNotNull hands back its argument, so validation and assignment are one step.
  this.function = Preconditions.checkNotNull(function);
  this.moduleName = null;
  this.functionName = null;
}
/**
* Create a {@link PythonFunction} from a moduleName and streaming function name.
*
* @param moduleName module name of streaming function.
* @param functionName function name of streaming function. {@code functionName} is the name
* of a python function, or class name of subclass of `ray.streaming.function.`
*/
public PythonFunction(String moduleName,
String functionName) {
Preconditions.checkArgument(StringUtils.isNotBlank(moduleName));
Preconditions.checkArgument(StringUtils.isNotBlank(functionName));
this.function = null;
this.moduleName = moduleName;
this.className = className;
this.functionName = functionName;
}
@ -70,10 +91,6 @@ public class PythonFunction implements Function {
return moduleName;
}
public String getClassName() {
return className;
}
public String getFunctionName() {
return functionName;
}
@ -82,34 +99,4 @@ public class PythonFunction implements Function {
return functionInterface;
}
/**
* Create a {@link PythonFunction} using python serialized function
*
* @param function serialized python function sent from python driver
*/
public static PythonFunction fromFunction(byte[] function) {
return new PythonFunction(function, null, null, null);
}
/**
* Create a {@link PythonFunction} using <code>moduleName</code> and
* <code>className</code>.
*
* @param moduleName python module name
* @param className python class name
*/
public static PythonFunction fromClassName(String moduleName, String className) {
return new PythonFunction(null, moduleName, className, null);
}
/**
* Create a {@link PythonFunction} using <code>moduleName</code> and
* <code>functionName</code>.
*
* @param moduleName python module name
* @param functionName python function name
*/
public static PythonFunction fromFunctionName(String moduleName, String functionName) {
return new PythonFunction(null, moduleName, null, functionName);
}
}

View file

@ -1,6 +1,8 @@
package io.ray.streaming.python;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.partition.Partition;
import org.apache.commons.lang3.StringUtils;
/**
* Represents a python partition function.
@ -13,28 +15,33 @@ import io.ray.streaming.api.partition.Partition;
* If this object is constructed from moduleName and functionName,
* python worker will use `importlib` to load python partition function.
* <p>
* TODO serialize to bytes using protobuf
*/
public class PythonPartition implements Partition {
public class PythonPartition implements Partition<Object> {
public static final PythonPartition BroadcastPartition = new PythonPartition(
"ray.streaming.partition", "BroadcastPartition", null);
"ray.streaming.partition", "BroadcastPartition");
public static final PythonPartition KeyPartition = new PythonPartition(
"ray.streaming.partition", "KeyPartition", null);
"ray.streaming.partition", "KeyPartition");
public static final PythonPartition RoundRobinPartition = new PythonPartition(
"ray.streaming.partition", "RoundRobinPartition", null);
"ray.streaming.partition", "RoundRobinPartition");
private byte[] partition;
private String moduleName;
private String className;
private String functionName;
/**
 * Create a python partition from bytes.
 *
 * @param partition the partition as bytes (presumably the serialized python partition
 *     function - confirm against the python side); must not be null.
 */
public PythonPartition(byte[] partition) {
  // checkNotNull hands back its argument, so validation and assignment are one step.
  this.partition = Preconditions.checkNotNull(partition);
}
public PythonPartition(String moduleName, String className, String functionName) {
/**
* Create a python partition from a moduleName and partition function name
* @param moduleName module name of python partition
* @param functionName function/class name of the partition function.
*/
public PythonPartition(String moduleName, String functionName) {
Preconditions.checkArgument(StringUtils.isNotBlank(moduleName));
Preconditions.checkArgument(StringUtils.isNotBlank(functionName));
this.moduleName = moduleName;
this.className = className;
this.functionName = functionName;
}
@ -53,10 +60,6 @@ public class PythonPartition implements Partition {
return moduleName;
}
public String getClassName() {
return className;
}
public String getFunctionName() {
return functionName;
}

View file

@ -1,6 +1,9 @@
package io.ray.streaming.python.stream;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.Stream;
import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonFunction.FunctionInterface;
@ -10,19 +13,39 @@ import io.ray.streaming.python.PythonPartition;
/**
* Represents a stream of data whose transformations will be executed in python.
*/
public class PythonDataStream extends Stream implements PythonStream {
public class PythonDataStream extends Stream<PythonDataStream, Object> implements PythonStream {
/**
 * Create a python data stream from a streaming context and a python operator.
 * Both arguments are forwarded unchanged to the superclass constructor.
 *
 * @param streamingContext the streaming context handed to the superclass.
 * @param pythonOperator the python operator handed to the superclass.
 */
protected PythonDataStream(StreamingContext streamingContext,
PythonOperator pythonOperator) {
super(streamingContext, pythonOperator);
}
/**
 * Create a python data stream from a streaming context, a python operator and an
 * explicit partition. All arguments are forwarded unchanged to the superclass constructor.
 *
 * @param streamingContext the streaming context handed to the superclass.
 * @param pythonOperator the python operator handed to the superclass.
 * @param partition the partition handed to the superclass.
 */
protected PythonDataStream(StreamingContext streamingContext,
PythonOperator pythonOperator,
Partition<Object> partition) {
super(streamingContext, pythonOperator, partition);
}
/**
 * Create a python data stream whose input is another python data stream.
 * Both arguments are forwarded unchanged to the superclass constructor.
 *
 * @param input the python stream used as input.
 * @param pythonOperator the python operator handed to the superclass.
 */
public PythonDataStream(PythonDataStream input, PythonOperator pythonOperator) {
super(input, pythonOperator);
}
protected PythonDataStream(Stream inputStream, PythonOperator pythonOperator) {
super(inputStream, pythonOperator);
public PythonDataStream(PythonDataStream input,
PythonOperator pythonOperator,
Partition<Object> partition) {
super(input, pythonOperator, partition);
}
/**
 * Create a python stream that references the passed java stream.
 * Changes made to the new stream are reflected in the referenced stream and vice versa.
 *
 * @param referencedStream the java stream to reference.
 */
public PythonDataStream(DataStream referencedStream) {
super(referencedStream);
}
/**
 * Apply a map function, identified by module name and function name, to this stream.
 *
 * @param moduleName module name of the python function.
 * @param funcName name of the python function inside that module.
 * @return A new PythonDataStream.
 */
public PythonDataStream map(String moduleName, String funcName) {
  PythonFunction mapFunc = new PythonFunction(moduleName, funcName);
  return map(mapFunc);
}
/**
@ -36,6 +59,10 @@ public class PythonDataStream extends Stream implements PythonStream {
return new PythonDataStream(this, new PythonOperator(func));
}
/**
 * Apply a flat-map function, identified by module name and function name, to this stream.
 *
 * @param moduleName module name of the python function.
 * @param funcName name of the python function inside that module.
 * @return A new PythonDataStream.
 */
public PythonDataStream flatMap(String moduleName, String funcName) {
  PythonFunction flatMapFunc = new PythonFunction(moduleName, funcName);
  return flatMap(flatMapFunc);
}
/**
* Apply a flat-map function to this stream.
*
@ -47,6 +74,10 @@ public class PythonDataStream extends Stream implements PythonStream {
return new PythonDataStream(this, new PythonOperator(func));
}
/**
 * Apply a filter function, identified by module name and function name, to this stream.
 *
 * @param moduleName module name of the python function.
 * @param funcName name of the python function inside that module.
 * @return A new PythonDataStream.
 */
public PythonDataStream filter(String moduleName, String funcName) {
  PythonFunction filterFunc = new PythonFunction(moduleName, funcName);
  return filter(filterFunc);
}
/**
* Apply a filter function to this stream.
*
@ -59,6 +90,10 @@ public class PythonDataStream extends Stream implements PythonStream {
return new PythonDataStream(this, new PythonOperator(func));
}
/**
 * Apply a sink function, identified by module name and function name, and get a StreamSink.
 *
 * @param moduleName module name of the python function.
 * @param funcName name of the python function inside that module.
 * @return A new StreamSink.
 */
public PythonStreamSink sink(String moduleName, String funcName) {
  PythonFunction sinkFunc = new PythonFunction(moduleName, funcName);
  return sink(sinkFunc);
}
/**
* Apply a sink function and get a StreamSink.
*
@ -70,6 +105,10 @@ public class PythonDataStream extends Stream implements PythonStream {
return new PythonStreamSink(this, new PythonOperator(func));
}
/**
 * Apply a key-by function, identified by module name and function name, to this stream.
 *
 * @param moduleName module name of the python function.
 * @param funcName name of the python function inside that module.
 * @return A new KeyDataStream.
 */
public PythonKeyDataStream keyBy(String moduleName, String funcName) {
  PythonFunction keyFunc = new PythonFunction(moduleName, funcName);
  return keyBy(keyFunc);
}
/**
* Apply a key-by function to this stream.
*
@ -77,6 +116,7 @@ public class PythonDataStream extends Stream implements PythonStream {
* @return A new KeyDataStream.
*/
public PythonKeyDataStream keyBy(PythonFunction func) {
  // Must run first: throws when the parent stream is a java stream.
  checkPartitionCall();
  // Mark the function as a key function before wrapping it into an operator.
  func.setFunctionInterface(FunctionInterface.KEY_FUNCTION);
  PythonOperator keyOperator = new PythonOperator(func);
  return new PythonKeyDataStream(this, keyOperator);
}
@ -87,8 +127,8 @@ public class PythonDataStream extends Stream implements PythonStream {
* @return This stream.
*/
public PythonDataStream broadcast() {
this.partition = PythonPartition.BroadcastPartition;
return this;
checkPartitionCall();
return setPartition(PythonPartition.BroadcastPartition);
}
/**
@ -98,19 +138,33 @@ public class PythonDataStream extends Stream implements PythonStream {
* @return This stream.
*/
public PythonDataStream partitionBy(PythonPartition partition) {
this.partition = partition;
return this;
checkPartitionCall();
return setPartition(partition);
}
/**
* Set parallelism to current transformation.
*
* @param parallelism The parallelism to set.
* @return This stream.
* If parent stream is a java stream, we can't call partition related methods
* in the python stream.
*/
public PythonDataStream setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
private void checkPartitionCall() {
  // Nothing to reject unless an input stream exists and that input is a java stream.
  if (getInputStream() == null || getInputStream().getLanguage() != Language.JAVA) {
    return;
  }
  throw new RuntimeException("Partition related methods can't be called on a " +
      "python stream if parent stream is a java stream.");
}
/**
 * View this stream as a java stream.
 * The returned stream and this stream are the same logical stream and share one stream id,
 * so a change made through either of them is visible through the other.
 *
 * @return a java {@code DataStream} that references this stream.
 */
public DataStream<Object> asJavaStream() {
  DataStream<Object> javaView = new DataStream<>(this);
  return javaView;
}
/**
 * Report the language of this stream.
 *
 * @return always {@code Language.PYTHON}.
 */
@Override
public Language getLanguage() {
return Language.PYTHON;
}
}

View file

@ -1,5 +1,7 @@
package io.ray.streaming.python.stream;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.KeyDataStream;
import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonFunction.FunctionInterface;
import io.ray.streaming.python.PythonOperator;
@ -8,11 +10,23 @@ import io.ray.streaming.python.PythonPartition;
/**
* Represents a python DataStream returned by a key-by operation.
*/
public class PythonKeyDataStream extends PythonDataStream implements PythonStream {
@SuppressWarnings("unchecked")
public class PythonKeyDataStream extends PythonDataStream implements PythonStream {
public PythonKeyDataStream(PythonDataStream input, PythonOperator pythonOperator) {
super(input, pythonOperator);
this.partition = PythonPartition.KeyPartition;
super(input, pythonOperator, PythonPartition.KeyPartition);
}
/**
 * Create a python key stream that references the passed java stream.
 * Changes made to the new stream are reflected in the referenced stream and vice versa.
 *
 * @param referencedStream the java stream to reference.
 */
public PythonKeyDataStream(DataStream referencedStream) {
super(referencedStream);
}
/**
 * Apply a reduce function, identified by module name and function name, to this stream.
 *
 * @param moduleName module name of the python function.
 * @param funcName name of the python function inside that module.
 * @return A new PythonDataStream.
 */
public PythonDataStream reduce(String moduleName, String funcName) {
  PythonFunction reduceFunc = new PythonFunction(moduleName, funcName);
  return reduce(reduceFunc);
}
/**
@ -26,9 +40,13 @@ public class PythonKeyDataStream extends PythonDataStream implements PythonStrea
return new PythonDataStream(this, new PythonOperator(func));
}
public PythonKeyDataStream setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
/**
 * View this key stream as a java key stream.
 * The returned stream and this stream are the same logical stream and share one stream id,
 * so a change made through either of them is visible through the other.
 *
 * @return a java {@code KeyDataStream} that references this stream.
 */
public KeyDataStream<Object, Object> asJavaStream() {
  KeyDataStream javaView = new KeyDataStream(this);
  return javaView;
}
}

View file

@ -1,5 +1,6 @@
package io.ray.streaming.python.stream;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.python.PythonOperator;
@ -9,12 +10,12 @@ import io.ray.streaming.python.PythonOperator;
public class PythonStreamSink extends StreamSink implements PythonStream {
public PythonStreamSink(PythonDataStream input, PythonOperator sinkOperator) {
super(input, sinkOperator);
this.streamingContext.addSink(this);
getStreamingContext().addSink(this);
}
public PythonStreamSink setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
/**
 * Report the language of this sink.
 *
 * @return always {@code Language.PYTHON}.
 */
@Override
public Language getLanguage() {
return Language.PYTHON;
}
}

View file

@ -13,17 +13,12 @@ import io.ray.streaming.python.PythonPartition;
public class PythonStreamSource extends PythonDataStream implements StreamSource {
private PythonStreamSource(StreamingContext streamingContext, PythonFunction sourceFunction) {
super(streamingContext, new PythonOperator(sourceFunction));
super.partition = PythonPartition.RoundRobinPartition;
}
public PythonStreamSource setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
super(streamingContext, new PythonOperator(sourceFunction),
PythonPartition.RoundRobinPartition);
}
public static PythonStreamSource from(StreamingContext streamingContext,
PythonFunction sourceFunction) {
PythonFunction sourceFunction) {
sourceFunction.setFunctionInterface(FunctionInterface.SOURCE_FUNCTION);
return new PythonStreamSource(streamingContext, sourceFunction);
}

View file

@ -21,7 +21,6 @@ public class Config {
public static final String CHANNEL_TYPE = "channel_type";
public static final String MEMORY_CHANNEL = "memory_channel";
public static final String NATIVE_CHANNEL = "native_channel";
public static final String DEFAULT_CHANNEL_TYPE = NATIVE_CHANNEL;
public static final String CHANNEL_SIZE = "channel_size";
public static final String CHANNEL_SIZE_DEFAULT = String.valueOf((long)Math.pow(10, 8));
public static final String IS_RECREATE = "streaming.is_recreate";

View file

@ -0,0 +1,40 @@
package io.ray.streaming.api.stream;
import static org.testng.Assert.assertEquals;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.operator.impl.MapOperator;
import io.ray.streaming.python.stream.PythonDataStream;
import io.ray.streaming.python.stream.PythonKeyDataStream;
import org.testng.annotations.Test;
/**
 * Checks that the views obtained through {@code asPythonStream()} and {@code asJavaStream()}
 * report the same id and the same parallelism as the stream they were created from.
 */
@SuppressWarnings("unchecked")
public class StreamTest {

  @Test
  public void testReferencedDataStream() {
    DataStream origin = new DataStream(StreamingContext.buildContext(),
        new MapOperator(value -> null));
    PythonDataStream pythonView = origin.asPythonStream();
    DataStream javaView = pythonView.asJavaStream();

    // Every view reports the id of the stream it references.
    assertEquals(origin.getId(), pythonView.getId());
    assertEquals(origin.getId(), javaView.getId());

    // A change made through one view is visible through all of them.
    javaView.setParallelism(10);
    assertEquals(origin.getParallelism(), pythonView.getParallelism());
    assertEquals(origin.getParallelism(), javaView.getParallelism());
  }

  @Test
  public void testReferencedKeyDataStream() {
    DataStream source = new DataStream(StreamingContext.buildContext(),
        new MapOperator(value -> null));
    KeyDataStream keyedOrigin = source.keyBy(value -> null);
    PythonKeyDataStream pythonKeyedView = keyedOrigin.asPythonStream();
    KeyDataStream javaKeyedView = pythonKeyedView.asJavaStream();

    // Every view reports the id of the keyed stream it references.
    assertEquals(keyedOrigin.getId(), pythonKeyedView.getId());
    assertEquals(keyedOrigin.getId(), javaKeyedView.getId());

    // A change made through one view is visible through all of them.
    javaKeyedView.setParallelism(10);
    assertEquals(keyedOrigin.getParallelism(), pythonKeyedView.getParallelism());
    assertEquals(keyedOrigin.getParallelism(), javaKeyedView.getParallelism());
  }
}

View file

@ -38,7 +38,7 @@ public class JobGraphBuilderTest {
public JobGraph buildDataSyncJobGraph() {
StreamingContext streamingContext = StreamingContext.buildContext();
DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext,
DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
Lists.newArrayList("a", "b", "c"));
StreamSink streamSink = dataStream.sink(x -> LOG.info(x));
JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink));
@ -73,7 +73,7 @@ public class JobGraphBuilderTest {
public JobGraph buildKeyByJobGraph() {
StreamingContext streamingContext = StreamingContext.buildContext();
DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext,
DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
Lists.newArrayList("1", "2", "3", "4"));
StreamSink streamSink = dataStream.keyBy(x -> x)
.sink(x -> LOG.info(x));

View file

@ -36,6 +36,11 @@
<artifactId>flatbuffers-java</artifactId>
<version>1.9.0.1</version>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
<version>3.0.2</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
@ -56,6 +61,11 @@
<artifactId>owner</artifactId>
<version>1.0.10</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-all</artifactId>
@ -71,11 +81,6 @@
<artifactId>powermock-api-mockito</artifactId>
<version>1.6.6</version>
</dependency>
<dependency>
<groupId>org.powermock</groupId>
<artifactId>powermock-core</artifactId>
<version>1.6.6</version>
</dependency>
<dependency>
<groupId>org.powermock</groupId>
<artifactId>powermock-module-testng</artifactId>

View file

@ -1,9 +1,14 @@
package io.ray.streaming.runtime.core.collector;
import io.ray.runtime.serializer.Serializer;
import io.ray.api.BaseActor;
import io.ray.api.RayPyActor;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.collector.Collector;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.message.Record;
import io.ray.streaming.runtime.serialization.CrossLangSerializer;
import io.ray.streaming.runtime.serialization.JavaSerializer;
import io.ray.streaming.runtime.serialization.Serializer;
import io.ray.streaming.runtime.transfer.ChannelID;
import io.ray.streaming.runtime.transfer.DataWriter;
import java.nio.ByteBuffer;
@ -14,15 +19,24 @@ import org.slf4j.LoggerFactory;
public class OutputCollector implements Collector<Record> {
private static final Logger LOGGER = LoggerFactory.getLogger(OutputCollector.class);
private Partition partition;
private DataWriter writer;
private ChannelID[] outputQueues;
private final DataWriter writer;
private final ChannelID[] outputQueues;
private final Collection<BaseActor> targetActors;
private final Language[] targetLanguages;
private final Partition partition;
private final Serializer javaSerializer = new JavaSerializer();
private final Serializer crossLangSerializer = new CrossLangSerializer();
public OutputCollector(Collection<String> outputQueueIds,
DataWriter writer,
public OutputCollector(DataWriter writer,
Collection<String> outputQueueIds,
Collection<BaseActor> targetActors,
Partition partition) {
this.outputQueues = outputQueueIds.stream().map(ChannelID::from).toArray(ChannelID[]::new);
this.writer = writer;
this.outputQueues = outputQueueIds.stream().map(ChannelID::from).toArray(ChannelID[]::new);
this.targetActors = targetActors;
this.targetLanguages = targetActors.stream()
.map(actor -> actor instanceof RayPyActor ? Language.PYTHON : Language.JAVA)
.toArray(Language[]::new);
this.partition = partition;
LOGGER.debug("OutputCollector constructed, outputQueueIds:{}, partition:{}.",
outputQueueIds, this.partition);
@ -31,9 +45,32 @@ public class OutputCollector implements Collector<Record> {
@Override
public void collect(Record record) {
int[] partitions = this.partition.partition(record, outputQueues.length);
ByteBuffer msgBuffer = ByteBuffer.wrap(Serializer.encode(record).getLeft());
ByteBuffer javaBuffer = null;
ByteBuffer crossLangBuffer = null;
for (int partition : partitions) {
writer.write(outputQueues[partition], msgBuffer);
if (targetLanguages[partition] == Language.JAVA) {
// avoid repeated serialization
if (javaBuffer == null) {
byte[] bytes = javaSerializer.serialize(record);
javaBuffer = ByteBuffer.allocate(1 + bytes.length);
javaBuffer.put(Serializer.JAVA_TYPE_ID);
// TODO(chaokunyang) remove copy
javaBuffer.put(bytes);
javaBuffer.flip();
}
writer.write(outputQueues[partition], javaBuffer.duplicate());
} else {
// avoid repeated serialization
if (crossLangBuffer == null) {
byte[] bytes = crossLangSerializer.serialize(record);
crossLangBuffer = ByteBuffer.allocate(1 + bytes.length);
crossLangBuffer.put(Serializer.CROSS_LANG_TYPE_ID);
// TODO(chaokunyang) remove copy
crossLangBuffer.put(bytes);
crossLangBuffer.flip();
}
writer.write(outputQueues[partition], crossLangBuffer.duplicate());
}
}
}

View file

@ -12,6 +12,7 @@ import io.ray.streaming.runtime.core.graph.ExecutionNode;
import io.ray.streaming.runtime.core.graph.ExecutionTask;
import io.ray.streaming.runtime.generated.RemoteCall;
import io.ray.streaming.runtime.generated.Streaming;
import io.ray.streaming.runtime.serialization.MsgPackSerializer;
import java.util.Arrays;
public class GraphPbBuilder {
@ -74,11 +75,10 @@ public class GraphPbBuilder {
private byte[] serializeFunction(Function function) {
if (function instanceof PythonFunction) {
PythonFunction pyFunc = (PythonFunction) function;
// function_bytes, module_name, class_name, function_name, function_interface
// function_bytes, module_name, function_name, function_interface
return serializer.serialize(Arrays.asList(
pyFunc.getFunction(), pyFunc.getModuleName(),
pyFunc.getClassName(), pyFunc.getFunctionName(),
pyFunc.getFunctionInterface()
pyFunc.getFunctionName(), pyFunc.getFunctionInterface()
));
} else {
return new byte[0];
@ -88,10 +88,10 @@ public class GraphPbBuilder {
private byte[] serializePartition(Partition partition) {
if (partition instanceof PythonPartition) {
PythonPartition pythonPartition = (PythonPartition) partition;
// partition_bytes, module_name, class_name, function_name
// partition_bytes, module_name, function_name
return serializer.serialize(Arrays.asList(
pythonPartition.getPartition(), pythonPartition.getModuleName(),
pythonPartition.getClassName(), pythonPartition.getFunctionName()
pythonPartition.getFunctionName()
));
} else {
return new byte[0];

View file

@ -1,16 +1,21 @@
package io.ray.streaming.runtime.python;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Primitives;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonPartition;
import io.ray.streaming.python.stream.PythonStreamSource;
import io.ray.streaming.runtime.serialization.MsgPackSerializer;
import io.ray.streaming.runtime.util.ReflectionUtils;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.msgpack.core.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -68,7 +73,7 @@ public class PythonGateway {
Preconditions.checkNotNull(streamingContext);
try {
PythonStreamSource pythonStreamSource = PythonStreamSource.from(
streamingContext, PythonFunction.fromFunction(pySourceFunc));
streamingContext, new PythonFunction(pySourceFunc));
referenceMap.put(getReferenceId(pythonStreamSource), pythonStreamSource);
return serializer.serialize(getReferenceId(pythonStreamSource));
} catch (Exception e) {
@ -84,7 +89,7 @@ public class PythonGateway {
}
public byte[] createPyFunc(byte[] pyFunc) {
PythonFunction function = PythonFunction.fromFunction(pyFunc);
PythonFunction function = new PythonFunction(pyFunc);
referenceMap.put(getReferenceId(function), function);
return serializer.serialize(getReferenceId(function));
}
@ -98,15 +103,21 @@ public class PythonGateway {
public byte[] callFunction(byte[] paramsBytes) {
try {
List<Object> params = (List<Object>) serializer.deserialize(paramsBytes);
params = processReferenceParameters(params);
params = processParameters(params);
LOG.info("callFunction params {}", params);
String className = (String) params.get(0);
String funcName = (String) params.get(1);
Class<?> clz = Class.forName(className, true, this.getClass().getClassLoader());
Method method = ReflectionUtils.findMethod(clz, funcName);
Class[] paramsTypes = params.subList(2, params.size()).stream()
.map(Object::getClass).toArray(Class[]::new);
Method method = findMethod(clz, funcName, paramsTypes);
Object result = method.invoke(null, params.subList(2, params.size()).toArray());
referenceMap.put(getReferenceId(result), result);
return serializer.serialize(getReferenceId(result));
if (returnReference(result)) {
referenceMap.put(getReferenceId(result), result);
return serializer.serialize(getReferenceId(result));
} else {
return serializer.serialize(result);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
@ -115,31 +126,78 @@ public class PythonGateway {
public byte[] callMethod(byte[] paramsBytes) {
try {
List<Object> params = (List<Object>) serializer.deserialize(paramsBytes);
params = processReferenceParameters(params);
params = processParameters(params);
LOG.info("callMethod params {}", params);
Object obj = params.get(0);
String methodName = (String) params.get(1);
Method method = ReflectionUtils.findMethod(obj.getClass(), methodName);
Class<?> clz = obj.getClass();
Class[] paramsTypes = params.subList(2, params.size()).stream()
.map(Object::getClass).toArray(Class[]::new);
Method method = findMethod(clz, methodName, paramsTypes);
Object result = method.invoke(obj, params.subList(2, params.size()).toArray());
referenceMap.put(getReferenceId(result), result);
return serializer.serialize(getReferenceId(result));
if (returnReference(result)) {
referenceMap.put(getReferenceId(result), result);
return serializer.serialize(getReferenceId(result));
} else {
return serializer.serialize(result);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private List<Object> processReferenceParameters(List<Object> params) {
return params.stream().map(this::processReferenceParameter)
/**
 * Look up the method named {@code methodName} on {@code cls}.
 *
 * <p>If exactly one method carries that name it is returned without looking at the
 * parameter types. Otherwise a method is picked whose declared parameter types equal
 * {@code paramsTypes}, either as given or after unwrapping boxed types to primitives.
 *
 * @param cls class to search.
 * @param methodName name of the wanted method.
 * @param paramsTypes runtime classes of the arguments.
 * @return the selected method.
 * @throws IllegalArgumentException if several or no methods carry the name and none of them
 *     matches the parameter types.
 */
private static Method findMethod(Class<?> cls, String methodName, Class[] paramsTypes) {
  List<Method> candidates = ReflectionUtils.findMethods(cls, methodName);
  if (candidates.size() == 1) {
    return candidates.get(0);
  }
  // Convert all params types to primitive types if it's boxed type
  Class[] primitiveTypes = new Class[paramsTypes.length];
  for (int i = 0; i < paramsTypes.length; i++) {
    primitiveTypes[i] = Primitives.unwrap(paramsTypes[i]);
  }
  for (Method candidate : candidates) {
    Class<?>[] declaredTypes = candidate.getParameterTypes();
    if (Arrays.equals(declaredTypes, paramsTypes)
        || Arrays.equals(declaredTypes, primitiveTypes)) {
      return candidate;
    }
  }
  throw new IllegalArgumentException(
      String.format("Method %s with type %s doesn't exist on class %s",
          methodName, Arrays.toString(paramsTypes), cls));
}
/**
 * Tell whether a call result is handed back as a reference rather than as a value.
 * Numbers, strings and byte arrays go back by value; everything else by reference.
 */
private static boolean returnReference(Object value) {
  boolean passedByValue =
      value instanceof Number || value instanceof String || value instanceof byte[];
  return !passedByValue;
}
/**
 * Instantiate a class through its no-arg constructor, keep the instance in the
 * reference map and return the serialized reference id.
 *
 * @param classNameBytes serialized name of the class to instantiate.
 * @return the serialized reference id of the new instance.
 * @throws IllegalArgumentException if the class can't be found or instantiated.
 */
public byte[] newInstance(byte[] classNameBytes) {
  final String targetClassName = (String) serializer.deserialize(classNameBytes);
  final Object created;
  try {
    ClassLoader loader = this.getClass().getClassLoader();
    created = Class.forName(targetClassName, true, loader).newInstance();
  } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) {
    throw new IllegalArgumentException(
        String.format("Create instance for class %s failed", targetClassName), e);
  }
  referenceMap.put(getReferenceId(created), created);
  return serializer.serialize(getReferenceId(created));
}
/**
 * Run every parameter through processParameter and collect the results in the same order.
 */
private List<Object> processParameters(List<Object> params) {
  List<Object> processed = params.stream()
      .map(param -> processParameter(param))
      .collect(Collectors.toList());
  return processed;
}
private Object processReferenceParameter(Object o) {
private Object processParameter(Object o) {
  // A string that is a key of referenceMap stands for the object stored under it.
  Object referenced = (o instanceof String) ? referenceMap.get(o) : null;
  if (referenced != null) {
    return referenced;
  }
  // Since python can't represent byte/short, we convert all Byte/Short to Integer
  boolean isNarrowInteger = o instanceof Byte || o instanceof Short;
  if (isNarrowInteger) {
    return ((Number) o).intValue();
  }
  return o;
}

View file

@ -41,15 +41,11 @@ public class JobSchedulerImpl implements JobScheduler {
public void schedule(JobGraph jobGraph, Map<String, String> jobConfig) {
this.jobConfig = jobConfig;
this.jobGraph = jobGraph;
if (Ray.internal() == null) {
System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
Ray.init();
}
ExecutionGraph executionGraph = this.taskAssigner.assign(this.jobGraph);
List<ExecutionNode> executionNodes = executionGraph.getExecutionNodeList();
boolean hasPythonNode = executionNodes.stream()
.allMatch(node -> node.getLanguage() == Language.PYTHON);
.anyMatch(node -> node.getLanguage() == Language.PYTHON);
RemoteCall.ExecutionGraph executionGraphPb = null;
if (hasPythonNode) {
executionGraphPb = new GraphPbBuilder().buildExecutionGraphPb(executionGraph);

View file

@ -2,6 +2,8 @@ package io.ray.streaming.runtime.schedule;
import io.ray.api.BaseActor;
import io.ray.api.Ray;
import io.ray.api.RayActor;
import io.ray.api.RayPyActor;
import io.ray.api.function.PyActorClass;
import io.ray.streaming.jobgraph.JobEdge;
import io.ray.streaming.jobgraph.JobGraph;
@ -15,8 +17,11 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class TaskAssignerImpl implements TaskAssigner {
private static final Logger LOG = LoggerFactory.getLogger(TaskAssignerImpl.class);
/**
* Assign an optimized logical plan to execution graph.
@ -61,11 +66,17 @@ public class TaskAssignerImpl implements TaskAssigner {
private BaseActor createWorker(JobVertex jobVertex) {
switch (jobVertex.getLanguage()) {
case PYTHON:
return Ray.createActor(
case PYTHON: {
RayPyActor worker = Ray.createActor(
new PyActorClass("ray.streaming.runtime.worker", "JobWorker"));
case JAVA:
return Ray.createActor(JobWorker::new);
LOG.info("Created python worker {}", worker);
return worker;
}
case JAVA: {
RayActor<JobWorker> worker = Ray.createActor(JobWorker::new);
LOG.info("Created java worker {}", worker);
return worker;
}
default:
throw new UnsupportedOperationException(
"Unsupported language " + jobVertex.getLanguage());

View file

@ -0,0 +1,62 @@
package io.ray.streaming.runtime.serialization;
import io.ray.streaming.message.KeyRecord;
import io.ray.streaming.message.Record;
import java.util.Arrays;
import java.util.List;
/**
* A serializer for cross-lang serialization between java/python.
* TODO implements a more sophisticated serialization framework
*/
public class CrossLangSerializer implements Serializer {
private static final byte RECORD_TYPE_ID = 0;
private static final byte KEY_RECORD_TYPE_ID = 1;
private MsgPackSerializer msgPackSerializer = new MsgPackSerializer();
public byte[] serialize(Object object) {
Record record = (Record) object;
Object value = record.getValue();
Class<? extends Record> clz = record.getClass();
if (clz == Record.class) {
return msgPackSerializer.serialize(Arrays.asList(
RECORD_TYPE_ID, record.getStream(), value));
} else if (clz == KeyRecord.class) {
KeyRecord keyRecord = (KeyRecord) record;
Object key = keyRecord.getKey();
return msgPackSerializer.serialize(Arrays.asList(
KEY_RECORD_TYPE_ID, keyRecord.getStream(), key, value));
} else {
throw new UnsupportedOperationException(
String.format("Serialize %s is unsupported.", record));
}
}
@SuppressWarnings("unchecked")
public Record deserialize(byte[] bytes) {
List list = (List) msgPackSerializer.deserialize(bytes);
Byte typeId = (Byte) list.get(0);
switch (typeId) {
case RECORD_TYPE_ID: {
String stream = (String) list.get(1);
Object value = list.get(2);
Record record = new Record(value);
record.setStream(stream);
return record;
}
case KEY_RECORD_TYPE_ID: {
String stream = (String) list.get(1);
Object key = list.get(2);
Object value = list.get(3);
KeyRecord keyRecord = new KeyRecord(key, value);
keyRecord.setStream(stream);
return keyRecord;
}
default:
throw new UnsupportedOperationException("Unsupported type " + typeId);
}
}
}

View file

@ -0,0 +1,15 @@
package io.ray.streaming.runtime.serialization;
import io.ray.runtime.serializer.FstSerializer;
/**
 * A {@link Serializer} that delegates to ray's {@link FstSerializer}.
 */
public class JavaSerializer implements Serializer {

  /** Encodes {@code object} with {@code FstSerializer.encode}. */
  @Override
  public byte[] serialize(Object object) {
    byte[] encoded = FstSerializer.encode(object);
    return encoded;
  }

  /** Decodes {@code bytes} with {@code FstSerializer.decode}. */
  @Override
  public <T> T deserialize(byte[] bytes) {
    T decoded = FstSerializer.decode(bytes);
    return decoded;
  }
}

View file

@ -1,4 +1,4 @@
package io.ray.streaming.runtime.python;
package io.ray.streaming.runtime.serialization;
import com.google.common.io.BaseEncoding;
import java.util.ArrayList;
@ -31,6 +31,10 @@ public class MsgPackSerializer {
Class<?> clz = obj.getClass();
if (clz == Boolean.class) {
packer.packBoolean((Boolean) obj);
} else if (clz == Byte.class) {
packer.packByte((Byte) obj);
} else if (clz == Short.class) {
packer.packShort((Short) obj);
} else if (clz == Integer.class) {
packer.packInt((Integer) obj);
} else if (clz == Long.class) {
@ -84,7 +88,11 @@ public class MsgPackSerializer {
return value.asBooleanValue().getBoolean();
case INTEGER:
IntegerValue iv = value.asIntegerValue();
if (iv.isInIntRange()) {
if (iv.isInByteRange()) {
return iv.toByte();
} else if (iv.isInShortRange()) {
return iv.toShort();
} else if (iv.isInIntRange()) {
return iv.toInt();
} else if (iv.isInLongRange()) {
return iv.toLong();

View file

@ -0,0 +1,12 @@
package io.ray.streaming.runtime.serialization;
/**
 * Common contract for the streaming runtime's object serializers.
 *
 * <p>The type id constants tag which kind of serializer encoded a payload;
 * {@code InputStreamTask}, for example, reads a leading type id byte and picks the
 * matching deserializer.
 */
public interface Serializer {
  /** Type id for payloads encoded in the cross-language format. */
  byte CROSS_LANG_TYPE_ID = 0;
  /** Type id for payloads encoded by the java serializer. */
  byte JAVA_TYPE_ID = 1;
  /** Type id for payloads encoded by the python serializer. */
  byte PYTHON_TYPE_ID = 2;

  /** Encodes {@code object} into bytes. */
  byte[] serialize(Object object);

  /** Decodes an object from {@code bytes}. */
  <T> T deserialize(byte[] bytes);
}

View file

@ -20,7 +20,7 @@ import java.util.Map;
*/
public class ChannelCreationParametersBuilder {
public class Parameter {
public static class Parameter {
private ActorId actorId;
private FunctionDescriptor asyncFunctionDescriptor;
@ -138,7 +138,7 @@ public class ChannelCreationParametersBuilder {
parameter.setAsyncFunctionDescriptor(pyAsyncFunctionDesc);
parameter.setSyncFunctionDescriptor(pySyncFunctionDesc);
} else {
Preconditions.checkArgument(false, "Invalid actor type");
throw new IllegalArgumentException("Invalid actor type");
}
parameters.add(parameter);
}
@ -152,10 +152,10 @@ public class ChannelCreationParametersBuilder {
}
public String toString() {
String str = "";
StringBuilder str = new StringBuilder();
for (Parameter param : parameters) {
str += param.toString();
str.append(param.toString());
}
return str;
return str.toString();
}
}

View file

@ -40,7 +40,7 @@ public class DataReader {
}
long timerInterval = Long.parseLong(
conf.getOrDefault(Config.TIMER_INTERVAL_MS, "-1"));
String channelType = conf.getOrDefault(Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
String channelType = conf.get(Config.CHANNEL_TYPE);
boolean isMock = false;
if (Config.MEMORY_CHANNEL.equals(channelType)) {
isMock = true;

View file

@ -37,7 +37,7 @@ public class DataWriter {
Map<String, String> conf) {
Preconditions.checkArgument(!outputChannels.isEmpty());
Preconditions.checkArgument(outputChannels.size() == toActors.size());
ChannelCreationParametersBuilder initialParameters =
ChannelCreationParametersBuilder initParameters =
new ChannelCreationParametersBuilder().buildOutputQueueParameters(outputChannels, toActors);
byte[][] outputChannelsBytes = outputChannels.stream()
.map(ChannelID::idStrToBytes).toArray(byte[][]::new);
@ -47,13 +47,14 @@ public class DataWriter {
for (int i = 0; i < outputChannels.size(); i++) {
msgIds[i] = 0;
}
String channelType = conf.getOrDefault(Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
String channelType = conf.get(Config.CHANNEL_TYPE);
boolean isMock = false;
if (Config.MEMORY_CHANNEL.equals(channelType)) {
if (Config.MEMORY_CHANNEL.equalsIgnoreCase(channelType)) {
isMock = true;
LOGGER.info("Using memory channel");
}
this.nativeWriterPtr = createWriterNative(
initialParameters,
initParameters,
outputChannelsBytes,
msgIds,
channelSize,

View file

@ -19,6 +19,7 @@ public class ReflectionUtils {
/**
* For covariant return type, return the most specific method.
*
* @return all methods named by {@code methodName}.
*/
public static List<Method> findMethods(Class<?> cls, String methodName) {

View file

@ -1,5 +1,6 @@
package io.ray.streaming.runtime.worker;
import io.ray.api.Ray;
import io.ray.streaming.runtime.core.graph.ExecutionGraph;
import io.ray.streaming.runtime.core.graph.ExecutionNode;
import io.ray.streaming.runtime.core.graph.ExecutionNode.NodeType;
@ -14,11 +15,8 @@ import io.ray.streaming.runtime.worker.context.WorkerContext;
import io.ray.streaming.runtime.worker.tasks.OneInputStreamTask;
import io.ray.streaming.runtime.worker.tasks.SourceStreamTask;
import io.ray.streaming.runtime.worker.tasks.StreamTask;
import io.ray.streaming.util.Config;
import java.io.Serializable;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -27,6 +25,8 @@ import org.slf4j.LoggerFactory;
*/
public class JobWorker implements Serializable {
private static final Logger LOGGER = LoggerFactory.getLogger(JobWorker.class);
// special flag to indicate that this actor is not ready
private static final byte[] NOT_READY_FLAG = new byte[4];
static {
EnvUtil.loadNativeLibraries();
@ -53,12 +53,11 @@ public class JobWorker implements Serializable {
this.nodeType = executionNode.getNodeType();
this.streamProcessor = ProcessBuilder
.buildProcessor(executionNode.getStreamOperator());
LOGGER.debug("Initializing StreamWorker, taskId: {}, operator: {}.", taskId, streamProcessor);
.buildProcessor(executionNode.getStreamOperator());
LOGGER.info("Initializing StreamWorker, pid {}, taskId: {}, operator: {}.",
EnvUtil.getJvmPid(), taskId, streamProcessor);
String channelType = (String) this.config.getOrDefault(
Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
if (channelType.equals(Config.NATIVE_CHANNEL)) {
if (!Ray.getRuntimeContext().isSingleProcess()) {
transferHandler = new TransferHandler();
}
task = createStreamTask();
@ -124,6 +123,9 @@ public class JobWorker implements Serializable {
* and receive result from this actor
*/
public byte[] onReaderMessageSync(byte[] buffer) {
  // Without a transfer handler there is nothing to dispatch to, so reply with the
  // not-ready flag instead of dereferencing null.
  return transferHandler == null
      ? NOT_READY_FLAG
      : transferHandler.onReaderMessageSync(buffer);
}
@ -139,6 +141,9 @@ public class JobWorker implements Serializable {
* and receive result from this actor
*/
public byte[] onWriterMessageSync(byte[] buffer) {
  // Without a transfer handler there is nothing to dispatch to, so reply with the
  // not-ready flag instead of dereferencing null.
  return transferHandler == null
      ? NOT_READY_FLAG
      : transferHandler.onWriterMessageSync(buffer);
}
}

View file

@ -1,7 +1,9 @@
package io.ray.streaming.runtime.worker.tasks;
import io.ray.runtime.serializer.Serializer;
import io.ray.streaming.runtime.core.processor.Processor;
import io.ray.streaming.runtime.serialization.CrossLangSerializer;
import io.ray.streaming.runtime.serialization.JavaSerializer;
import io.ray.streaming.runtime.serialization.Serializer;
import io.ray.streaming.runtime.transfer.Message;
import io.ray.streaming.runtime.worker.JobWorker;
import io.ray.streaming.util.Config;
@ -10,11 +12,15 @@ public abstract class InputStreamTask extends StreamTask {
private volatile boolean running = true;
private volatile boolean stopped = false;
private long readTimeoutMillis;
private final io.ray.streaming.runtime.serialization.Serializer javaSerializer;
private final io.ray.streaming.runtime.serialization.Serializer crossLangSerializer;
public InputStreamTask(int taskId, Processor processor, JobWorker streamWorker) {
super(taskId, processor, streamWorker);
readTimeoutMillis = Long.parseLong((String) streamWorker.getConfig()
.getOrDefault(Config.READ_TIMEOUT_MS, Config.DEFAULT_READ_TIMEOUT_MS));
javaSerializer = new JavaSerializer();
crossLangSerializer = new CrossLangSerializer();
}
@Override
@ -26,9 +32,15 @@ public abstract class InputStreamTask extends StreamTask {
while (running) {
Message item = reader.read(readTimeoutMillis);
if (item != null) {
byte[] bytes = new byte[item.body().remaining()];
byte[] bytes = new byte[item.body().remaining() - 1];
byte typeId = item.body().get();
item.body().get(bytes);
Object obj = Serializer.decode(bytes, Object.class);
Object obj;
if (typeId == Serializer.JAVA_TYPE_ID) {
obj = javaSerializer.deserialize(bytes);
} else {
obj = crossLangSerializer.deserialize(bytes);
}
processor.process(obj);
}
}

View file

@ -26,7 +26,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public abstract class StreamTask implements Runnable {
private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class);
protected int taskId;
@ -53,8 +52,8 @@ public abstract class StreamTask implements Runnable {
String queueSize = worker.getConfig()
.getOrDefault(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT);
queueConf.put(Config.CHANNEL_SIZE, queueSize);
String channelType = worker.getConfig()
.getOrDefault(Config.CHANNEL_TYPE, Config.MEMORY_CHANNEL);
String channelType = Ray.getRuntimeContext().isSingleProcess() ?
Config.MEMORY_CHANNEL : Config.NATIVE_CHANNEL;
queueConf.put(Config.CHANNEL_TYPE, channelType);
ExecutionGraph executionGraph = worker.getExecutionGraph();
@ -82,7 +81,7 @@ public abstract class StreamTask implements Runnable {
LOG.info("Create DataWriter succeed.");
writers.put(edge, writer);
Partition partition = edge.getPartition();
collectors.add(new OutputCollector(channelIDs, writer, partition));
collectors.add(new OutputCollector(writer, channelIDs, outputActors.values(), partition));
}
}
@ -106,8 +105,8 @@ public abstract class StreamTask implements Runnable {
reader = new DataReader(channelIDs, inputActors, queueConf);
}
RuntimeContext runtimeContext = new RayRuntimeContext(worker.getExecutionTask(),
worker.getConfig(), executionNode.getParallelism());
RuntimeContext runtimeContext = new RayRuntimeContext(
worker.getExecutionTask(), worker.getConfig(), executionNode.getParallelism());
processor.open(collectors, runtimeContext);

View file

@ -24,11 +24,13 @@ public abstract class BaseUnitTest {
@BeforeMethod
public void testBegin(Method method) {
LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: " + method.getName() + " began >>>>>>>>>>>>>>>>>>>>");
LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: {}.{} began >>>>>>>>>>>>>>>>>>>>",
method.getDeclaringClass(), method.getName());
}
@AfterMethod
public void testEnd(Method method) {
LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: " + method.getName() + " end >>>>>>>>>>>>>>>>>>");
LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: {}.{} end >>>>>>>>>>>>>>>>>>>>",
method.getDeclaringClass(), method.getName());
}
}

View file

@ -80,7 +80,7 @@ public class ExecutionGraphTest extends BaseUnitTest {
public static JobGraph buildJobGraph() {
StreamingContext streamingContext = StreamingContext.buildContext();
DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext,
DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
Lists.newArrayList("a", "b", "c"));
StreamSink streamSink = dataStream.sink(x -> LOG.info(x));

View file

@ -0,0 +1,56 @@
package io.ray.streaming.runtime.demo;
import io.ray.api.Ray;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.FilterFunction;
import io.ray.streaming.api.function.impl.MapFunction;
import io.ray.streaming.api.stream.DataStreamSource;
import io.ray.streaming.runtime.BaseUnitTest;
import java.io.Serializable;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.annotations.Test;
/**
 * Runs a job whose stream switches from java to python operators and back to java.
 */
public class HybridStreamTest extends BaseUnitTest implements Serializable {
  private static final Logger LOG = LoggerFactory.getLogger(HybridStreamTest.class);

  /** Logs the value and maps it to its string form. Not referenced inside this class. */
  public static class Mapper1 implements MapFunction<Object, Object> {

    @Override
    public Object map(Object value) {
      LOG.info("HybridStreamTest Mapper1 {}", value);
      String mapped = value.toString();
      return mapped;
    }
  }

  /** Logs the value and keeps it unless its string form contains "b". */
  public static class Filter1 implements FilterFunction<Object> {

    @Override
    public boolean filter(Object value) throws Exception {
      LOG.info("HybridStreamTest Filter1 {}", value);
      boolean containsB = value.toString().contains("b");
      return !containsB;
    }
  }

  @Test
  public void testHybridDataStream() throws InterruptedException {
    // Start from a clean runtime before the streaming context is built.
    Ray.shutdown();
    StreamingContext streamingContext = StreamingContext.buildContext();
    DataStreamSource<String> source =
        DataStreamSource.fromCollection(streamingContext, Arrays.asList("a", "b", "c"));
    // java map -> python map/filter -> java sink
    source
        .map(x -> x + x)
        .asPythonStream()
        .map("ray.streaming.tests.test_hybrid_stream", "map_func1")
        .filter("ray.streaming.tests.test_hybrid_stream", "filter_func1")
        .asJavaStream()
        .sink(x -> System.out.println("HybridStreamTest: " + x));
    streamingContext.execute("HybridStreamTestJob");
    // Give the job a few seconds to run before stopping it.
    TimeUnit.SECONDS.sleep(3);
    streamingContext.stop();
    LOG.info("HybridStreamTest succeed");
  }
}

View file

@ -1,6 +1,7 @@
package io.ray.streaming.runtime.demo;
import com.google.common.collect.ImmutableMap;
import io.ray.api.Ray;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.FlatMapFunction;
import io.ray.streaming.api.function.impl.ReduceFunction;
@ -29,6 +30,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable {
@Test
public void testWordCount() {
Ray.shutdown();
StreamingContext streamingContext = StreamingContext.buildContext();
Map<String, String> config = new HashMap<>();
config.put(Config.STREAMING_BATCH_MAX_COUNT, "1");
@ -36,7 +38,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable {
streamingContext.withConfig(config);
List<String> text = new ArrayList<>();
text.add("hello world eagle eagle eagle");
DataStreamSource<String> streamSource = DataStreamSource.buildSource(streamingContext, text);
DataStreamSource<String> streamSource = DataStreamSource.fromCollection(streamingContext, text);
streamSource
.flatMap((FlatMapFunction<String, WordAndCount>) (value, collector) -> {
String[] records = value.split(" ");
@ -62,6 +64,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable {
}
}
Assert.assertEquals(wordCount, ImmutableMap.of("eagle", 3, "hello", 1));
streamingContext.stop();
}
private static class WordAndCount implements Serializable {

View file

@ -3,6 +3,7 @@ package io.ray.streaming.runtime.python;
import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.jobgraph.JobGraphBuilder;
import io.ray.streaming.runtime.serialization.MsgPackSerializer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;

View file

@ -57,7 +57,7 @@ public class TaskAssignerImplTest extends BaseUnitTest {
public JobGraph buildDataSyncPlan() {
StreamingContext streamingContext = StreamingContext.buildContext();
DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext,
DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
Lists.newArrayList("a", "b", "c"));
DataStreamSink streamSink = dataStream.sink(LOGGER::info);
JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink));

View file

@ -0,0 +1,26 @@
package io.ray.streaming.runtime.serialization;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
import org.apache.commons.lang3.builder.EqualsBuilder;
import io.ray.streaming.message.KeyRecord;
import io.ray.streaming.message.Record;
import org.testng.annotations.Test;
public class CrossLangSerializerTest {

  /** Round-trips a Record and a KeyRecord through {@link CrossLangSerializer}. */
  @Test
  @SuppressWarnings("unchecked")
  public void testSerialize() {
    CrossLangSerializer serializer = new CrossLangSerializer();

    // Plain record: compared field by field via reflection.
    Record record = new Record("value");
    record.setStream("stream1");
    Record decodedRecord = serializer.deserialize(serializer.serialize(record));
    assertTrue(EqualsBuilder.reflectionEquals(record, decodedRecord));

    // Key record: compared via equals().
    KeyRecord keyRecord = new KeyRecord("key", "value");
    keyRecord.setStream("stream2");
    Record decodedKeyRecord = serializer.deserialize(serializer.serialize(keyRecord));
    assertEquals(keyRecord, decodedKeyRecord);
  }
}

View file

@ -1,4 +1,7 @@
package io.ray.streaming.runtime.python;
package io.ray.streaming.runtime.serialization;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
import java.util.ArrayList;
import java.util.Arrays;
@ -6,25 +9,37 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.testng.annotations.Test;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
@SuppressWarnings("unchecked")
public class MsgPackSerializerTest {
/** A single byte value survives a serialize/deserialize round trip. */
@Test
public void testSerializeByte() {
  MsgPackSerializer serializer = new MsgPackSerializer();
  byte[] encoded = serializer.serialize((byte) 1);
  assertEquals(serializer.deserialize(encoded), (byte) 1);
}
@Test
public void testSerialize() {
MsgPackSerializer serializer = new MsgPackSerializer();
assertEquals(serializer.deserialize
(serializer.serialize(Short.MAX_VALUE)), Short.MAX_VALUE);
assertEquals(serializer.deserialize(
serializer.serialize(Integer.MAX_VALUE)), Integer.MAX_VALUE);
assertEquals(serializer.deserialize(
serializer.serialize(Long.MAX_VALUE)), Long.MAX_VALUE);
Map map = new HashMap();
List list = new ArrayList<>();
list.add(null);
list.add(true);
list.add(1);
list.add(1.0d);
list.add("str");
map.put("k1", "value1");
map.put("k2", 2);
map.put("k2", new HashMap<>());
map.put("k3", list);
byte[] bytes = serializer.serialize(map);
Object o = serializer.deserialize(bytes);

View file

@ -5,6 +5,7 @@ import io.ray.api.Ray;
import io.ray.api.RayActor;
import io.ray.api.options.ActorCreationOptions;
import io.ray.api.options.ActorCreationOptions.Builder;
import io.ray.runtime.config.RayConfig;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.FlatMapFunction;
import io.ray.streaming.api.function.impl.ReduceFunction;
@ -67,7 +68,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
System.setProperty("ray.run-mode", "CLUSTER");
System.setProperty("ray.redirect-output", "true");
// ray init
RayConfig.reset();
Ray.init();
}
@ -142,6 +143,14 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
@Test(timeOut = 60000)
public void testWordCount() {
Ray.shutdown();
System.setProperty("ray.resources", "CPU:4,RES-A:4");
System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
System.setProperty("ray.run-mode", "CLUSTER");
System.setProperty("ray.redirect-output", "true");
// ray init
Ray.init();
LOGGER.info("testWordCount");
LOGGER.info("StreamingQueueTest.testWordCount run-mode: {}",
System.getProperty("ray.run-mode"));
@ -157,7 +166,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
streamingContext.withConfig(config);
List<String> text = new ArrayList<>();
text.add("hello world eagle eagle eagle");
DataStreamSource<String> streamSource = DataStreamSource.buildSource(streamingContext, text);
DataStreamSource<String> streamSource = DataStreamSource.fromCollection(streamingContext, text);
streamSource
.flatMap((FlatMapFunction<String, WordAndCount>) (value, collector) -> {
String[] records = value.split(" ");
@ -176,7 +185,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
serializeResultToFile(resultFile, wordCount);
});
streamingContext.execute("testWordCount");
streamingContext.execute("testSQWordCount");
Map<String, Integer> checkWordCount =
(Map<String, Integer>) deserializeResultFromFile(resultFile);

View file

@ -23,8 +23,11 @@ bazel test //streaming/java:all --test_tag_filters="checkstyle" --build_tests_on
echo "Running streaming tests."
java -cp "$ROOT_DIR"/../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar\
org.testng.TestNG -d /tmp/ray_streaming_java_test_output "$ROOT_DIR"/testng.xml
org.testng.TestNG -d /tmp/ray_streaming_java_test_output "$ROOT_DIR"/testng.xml ||
exit_code=$?
if [ -z ${exit_code+x} ]; then
exit_code=0
fi
echo "Streaming TestNG results"
if [ -f "/tmp/ray_streaming_java_test_output/testng-results.xml" ] ; then
cat /tmp/ray_streaming_java_test_output/testng-results.xml

View file

@ -1,10 +1,13 @@
import logging
import pickle
import typing
from abc import ABC, abstractmethod
from ray import Language
from ray.actor import ActorHandle
from ray.streaming import function
from ray.streaming import message
from ray.streaming import partition
from ray.streaming.runtime import serialization
from ray.streaming.runtime.transfer import ChannelID, DataWriter
logger = logging.getLogger(__name__)
@ -31,19 +34,46 @@ class CollectionCollector(Collector):
class OutputCollector(Collector):
def __init__(self, channel_ids: typing.List[str], writer: DataWriter,
def __init__(self, writer: DataWriter, channel_ids: typing.List[str],
target_actors: typing.List[ActorHandle],
partition_func: partition.Partition):
self._channel_ids = [ChannelID(id_str) for id_str in channel_ids]
self._writer = writer
self._channel_ids = [ChannelID(id_str) for id_str in channel_ids]
self._target_languages = []
for actor in target_actors:
if actor._ray_actor_language == Language.PYTHON:
self._target_languages.append(function.Language.PYTHON)
elif actor._ray_actor_language == Language.JAVA:
self._target_languages.append(function.Language.JAVA)
else:
raise Exception("Unsupported language {}"
.format(actor._ray_actor_language))
self._partition_func = partition_func
self.python_serializer = serialization.PythonSerializer()
self.cross_lang_serializer = serialization.CrossLangSerializer()
logger.info(
"Create OutputCollector, channel_ids {}, partition_func {}".format(
channel_ids, partition_func))
def collect(self, record):
partitions = self._partition_func.partition(record,
len(self._channel_ids))
serialized_message = pickle.dumps(record)
partitions = self._partition_func \
.partition(record, len(self._channel_ids))
python_buffer = None
cross_lang_buffer = None
for partition_index in partitions:
self._writer.write(self._channel_ids[partition_index],
serialized_message)
if self._target_languages[partition_index] == \
function.Language.PYTHON:
# avoid repeated serialization
if python_buffer is None:
python_buffer = self.python_serializer.serialize(record)
self._writer.write(
self._channel_ids[partition_index],
serialization._PYTHON_TYPE_ID + python_buffer)
else:
# avoid repeated serialization
if cross_lang_buffer is None:
cross_lang_buffer = self.cross_lang_serializer.serialize(
record)
self._writer.write(
self._channel_ids[partition_index],
serialization._CROSS_LANG_TYPE_ID + cross_lang_buffer)

View file

@ -1,4 +1,4 @@
from abc import ABC
from abc import ABC, abstractmethod
from ray.streaming import function
from ray.streaming import partition
@ -19,7 +19,6 @@ class Stream(ABC):
self.streaming_context = input_stream.streaming_context
else:
self.streaming_context = streaming_context
self.parallelism = 1
def get_streaming_context(self):
return self.streaming_context
@ -29,7 +28,8 @@ class Stream(ABC):
Returns:
the parallelism of this transformation
"""
return self.parallelism
return self._gateway_client(). \
call_method(self._j_stream, "getParallelism")
def set_parallelism(self, parallelism: int):
"""Sets the parallelism of this transformation
@ -40,7 +40,6 @@ class Stream(ABC):
Returns:
self
"""
self.parallelism = parallelism
self._gateway_client(). \
call_method(self._j_stream, "setParallelism", parallelism)
return self
@ -60,6 +59,10 @@ class Stream(ABC):
return self._gateway_client(). \
call_method(self._j_stream, "getId")
@abstractmethod
def get_language(self):
    """Return the language of this stream.

    Subclasses in this module return a ``function.Language`` member.
    """
def _gateway_client(self):
return self.get_streaming_context()._gateway_client
@ -75,6 +78,9 @@ class DataStream(Stream):
super().__init__(
input_stream, j_stream, streaming_context=streaming_context)
def get_language(self):
    """Return ``function.Language.PYTHON``."""
    return function.Language.PYTHON
def map(self, func):
"""
Applies a Map transformation on a :class:`DataStream`.
@ -158,6 +164,7 @@ class DataStream(Stream):
Returns:
A KeyDataStream
"""
self._check_partition_call()
if not isinstance(func, function.KeyFunction):
func = function.SimpleKeyFunction(func)
j_func = self._gateway_client().create_py_func(
@ -175,6 +182,7 @@ class DataStream(Stream):
Returns:
The DataStream with broadcast partitioning set.
"""
self._check_partition_call()
self._gateway_client().call_method(self._j_stream, "broadcast")
return self
@ -191,6 +199,7 @@ class DataStream(Stream):
Returns:
The DataStream with specified partitioning set.
"""
self._check_partition_call()
if not isinstance(partition_func, partition.Partition):
partition_func = partition.SimplePartition(partition_func)
j_partition = self._gateway_client().create_py_func(
@ -199,6 +208,16 @@ class DataStream(Stream):
call_method(self._j_stream, "partitionBy", j_partition)
return self
def _check_partition_call(self):
    """Reject partition related calls when the parent is a java stream.

    Raises:
        Exception: if ``self.input_stream`` is set and reports
            ``function.Language.JAVA`` as its language.
    """
    parent = self.input_stream
    if parent is None:
        return
    if parent.get_language() == function.Language.JAVA:
        raise Exception("Partition related methods can't be called on a "
                        "python stream if parent stream is a java stream.")
def sink(self, func):
"""
Create a StreamSink with the given sink.
@ -217,8 +236,97 @@ class DataStream(Stream):
call_method(self._j_stream, "sink", j_func)
return StreamSink(self, j_stream, func)
def as_java_stream(self):
    """View this stream as a ``JavaDataStream``.

    Calls ``asJavaStream`` on the wrapped java stream object. The returned
    stream and this stream are the same logical stream with the same
    stream id, so changes made through either one are visible in the other.
    """
    gateway = self._gateway_client()
    j_stream = gateway.call_method(self._j_stream, "asJavaStream")
    return JavaDataStream(self, j_stream)
class KeyDataStream(Stream):
class JavaDataStream(Stream):
    """
    Represents a stream of data which applies a transformation executed by
    java. It's also a wrapper of java
    `io.ray.streaming.api.stream.DataStream`

    Every transformation takes the qualified class name of a java function,
    instantiates it through the gateway and calls the corresponding method
    on the wrapped java stream.
    """

    def __init__(self, input_stream, j_stream, streaming_context=None):
        super().__init__(
            input_stream, j_stream, streaming_context=streaming_context)

    def get_language(self):
        """Return ``function.Language.JAVA``."""
        return function.Language.JAVA

    def map(self, java_func_class):
        """See io.ray.streaming.api.stream.DataStream.map"""
        return JavaDataStream(self, self._unary_call("map", java_func_class))

    def flat_map(self, java_func_class):
        """See io.ray.streaming.api.stream.DataStream.flatMap"""
        return JavaDataStream(self, self._unary_call("flatMap",
                                                     java_func_class))

    def filter(self, java_func_class):
        """See io.ray.streaming.api.stream.DataStream.filter"""
        return JavaDataStream(self, self._unary_call("filter",
                                                     java_func_class))

    def key_by(self, java_func_class):
        """See io.ray.streaming.api.stream.DataStream.keyBy"""
        self._check_partition_call()
        return JavaKeyDataStream(self,
                                 self._unary_call("keyBy", java_func_class))

    def broadcast(self, java_func_class):
        """See io.ray.streaming.api.stream.DataStream.broadcast"""
        # NOTE(review): this forwards a function instance to the java
        # `broadcast` call - confirm the java method accepts an argument.
        self._check_partition_call()
        return JavaDataStream(self,
                              self._unary_call("broadcast", java_func_class))

    def partition_by(self, java_func_class):
        """See io.ray.streaming.api.stream.DataStream.partitionBy"""
        self._check_partition_call()
        return JavaDataStream(self,
                              self._unary_call("partitionBy", java_func_class))

    def sink(self, java_func_class):
        """See io.ray.streaming.api.stream.DataStream.sink"""
        return JavaStreamSink(self, self._unary_call("sink", java_func_class))

    def as_python_stream(self):
        """
        Convert this stream as a python DataStream.
        The converted stream and this stream are the same logical stream,
        which has same stream id. Changes in converted stream will be reflected
        in this stream and vice versa.
        """
        j_stream = self._gateway_client(). \
            call_method(self._j_stream, "asPythonStream")
        return DataStream(self, j_stream)

    def _check_partition_call(self):
        """
        If parent stream is a python stream, we can't call partition related
        methods in the java stream

        Raises:
            Exception: if the parent stream's language is python.
        """
        if self.input_stream is not None and \
                self.input_stream.get_language() == function.Language.PYTHON:
            # The trailing space matters: the adjacent literals are
            # concatenated into a single message.
            raise Exception("Partition related methods can't be called on a "
                            "java stream if parent stream is a python stream.")

    def _unary_call(self, func_name, java_func_class):
        """Instantiate `java_func_class` in java and pass it to the method
        `func_name` of the wrapped java stream.

        Returns:
            Whatever the gateway returns for that java call.
        """
        j_func = self._gateway_client().new_instance(java_func_class)
        j_stream = self._gateway_client(). \
            call_method(self._j_stream, func_name, j_func)
        return j_stream
class KeyDataStream(DataStream):
"""Represents a DataStream returned by a key-by operation.
Wrapper of java io.ray.streaming.python.stream.PythonKeyDataStream
"""
@ -251,6 +359,43 @@ class KeyDataStream(Stream):
call_method(self._j_stream, "reduce", j_func)
return DataStream(self, j_stream)
def as_java_stream(self):
    """View this stream as a ``JavaKeyDataStream``.

    Calls ``asJavaStream`` on the wrapped java stream object. The returned
    stream and this stream are the same logical stream with the same
    stream id, so changes made through either one are visible in the other.
    """
    gateway = self._gateway_client()
    j_stream = gateway.call_method(self._j_stream, "asJavaStream")
    return JavaKeyDataStream(self, j_stream)
class JavaKeyDataStream(JavaDataStream):
    """A java DataStream returned by a key-by operation.

    Wraps the java-side KeyDataStream object.
    """

    def __init__(self, input_stream, j_stream):
        super().__init__(input_stream, j_stream)

    def reduce(self, java_func_class):
        """Apply the java reduce function named by `java_func_class`.

        Returns:
            A JavaDataStream wrapping the result of the java `reduce` call.
        """
        j_stream = super()._unary_call("reduce", java_func_class)
        return JavaDataStream(self, j_stream)

    def as_python_stream(self):
        """View this stream as a python ``KeyDataStream``.

        Calls ``asPythonStream`` on the wrapped java stream object. The
        returned stream and this stream are the same logical stream with the
        same stream id, so changes made through either one are visible in
        the other.
        """
        gateway = self._gateway_client()
        j_stream = gateway.call_method(self._j_stream, "asPythonStream")
        return KeyDataStream(self, j_stream)
class StreamSource(DataStream):
"""Represents a source of the DataStream.
@ -261,9 +406,12 @@ class StreamSource(DataStream):
super().__init__(None, j_stream, streaming_context=streaming_context)
self.source_func = source_func
def get_language(self):
    """Return ``function.Language.PYTHON``."""
    return function.Language.PYTHON
@staticmethod
def build_source(streaming_context, func):
"""Build a StreamSource source from a collection.
"""Build a StreamSource source from a source function.
Args:
streaming_context: Stream context
func: A instance of `SourceFunction`
@ -275,6 +423,34 @@ class StreamSource(DataStream):
return StreamSource(j_stream, streaming_context, func)
class JavaStreamSource(JavaDataStream):
    """Represents a source of the java DataStream.
    Wrapper of java io.ray.streaming.api.stream.DataStreamSource
    """

    def __init__(self, j_stream, streaming_context):
        # A source has no input stream, hence the explicit context.
        super().__init__(None, j_stream, streaming_context=streaming_context)

    def get_language(self):
        """Return ``function.Language.JAVA``."""
        return function.Language.JAVA

    @staticmethod
    def build_source(streaming_context, java_source_func_class):
        """Build a java StreamSource source from a java source function.

        Args:
            streaming_context: Stream context
            java_source_func_class: qualified class name of java SourceFunction
        Returns:
            A java StreamSource
        """
        # `_gateway_client` is an attribute of the context (Stream reads it
        # without calling it), so it must not be invoked here.
        gateway = streaming_context._gateway_client
        j_func = gateway.new_instance(java_source_func_class)
        # Class name and method name are two separate arguments; without the
        # comma the adjacent literals would be fused into one string.
        j_stream = gateway.call_function(
            "io.ray.streaming.api.stream.DataStreamSource", "fromSource",
            streaming_context._j_ctx, j_func)
        return JavaStreamSource(j_stream, streaming_context)
class StreamSink(Stream):
"""Represents a sink of the DataStream.
Wrapper of java io.ray.streaming.python.stream.PythonStreamSink
@ -282,3 +458,18 @@ class StreamSink(Stream):
def __init__(self, input_stream, j_stream, func):
super().__init__(input_stream, j_stream)
def get_language(self):
    """Return ``function.Language.PYTHON``."""
    return function.Language.PYTHON
class JavaStreamSink(Stream):
    """Represents a sink of the java DataStream.

    Wrapper of java org.ray.streaming.api.stream.StreamSink
    """

    def __init__(self, input_stream, j_stream):
        """Wrap the java sink `j_stream` fed by `input_stream`."""
        super().__init__(input_stream, j_stream)

    def get_language(self):
        """Return the language of this sink, always `Language.JAVA`."""
        return function.Language.JAVA

View file

@ -1,13 +1,19 @@
import enum
import importlib
import inspect
import sys
from abc import ABC, abstractmethod
import typing
from abc import ABC, abstractmethod
from ray import cloudpickle
from ray.streaming.runtime import gateway_client
class Language(enum.Enum):
    """Implementation language reported by a stream's `get_language()`."""
    JAVA = 0
    PYTHON = 1
class Function(ABC):
"""The base interface for all user-defined functions."""
@ -60,6 +66,7 @@ class MapFunction(Function):
for each input element.
"""
@abstractmethod
def map(self, value):
pass
@ -70,6 +77,7 @@ class FlatMapFunction(Function):
transform them into zero, one, or more elements.
"""
@abstractmethod
def flat_map(self, value, collector):
"""Takes an element from the input data set and transforms it into zero,
one, or more elements.
@ -87,6 +95,7 @@ class FilterFunction(Function):
The predicate decides whether to keep the element, or to discard it.
"""
@abstractmethod
def filter(self, value):
"""The filter function that evaluates the predicate.
@ -106,6 +115,7 @@ class KeyFunction(Function):
deterministic key for that object.
"""
@abstractmethod
def key_by(self, value):
"""User-defined function that deterministically extracts the key from
an object.
@ -126,6 +136,7 @@ class ReduceFunction(Function):
them into one.
"""
@abstractmethod
def reduce(self, old_value, new_value):
"""
The core method of ReduceFunction, combining two values into one value
@ -145,6 +156,7 @@ class ReduceFunction(Function):
class SinkFunction(Function):
"""Interface for implementing user defined sink functionality."""
@abstractmethod
def sink(self, value):
"""Writes the given value to the sink. This function is called for
every record."""
@ -283,7 +295,8 @@ def load_function(descriptor_func_bytes: bytes):
Returns:
a streaming function
"""
function_bytes, module_name, class_name, function_name, function_interface\
assert len(descriptor_func_bytes) > 0
function_bytes, module_name, function_name, function_interface\
= gateway_client.deserialize(descriptor_func_bytes)
if function_bytes:
return deserialize(function_bytes)
@ -292,16 +305,18 @@ def load_function(descriptor_func_bytes: bytes):
assert function_interface
function_interface = getattr(sys.modules[__name__], function_interface)
mod = importlib.import_module(module_name)
if class_name:
assert function_name is None
cls = getattr(mod, class_name)
assert issubclass(cls, function_interface)
return cls()
else:
assert function_name
func = getattr(mod, function_name)
assert function_name
func = getattr(mod, function_name)
# If func is a python function, user function is a simple python
# function, which will be wrapped as a SimpleXXXFunction.
# If func is a python class, user function is a sub class
# of XXXFunction.
if inspect.isfunction(func):
simple_func_class = _get_simple_function_class(function_interface)
return simple_func_class(func)
else:
assert issubclass(func, function_interface)
return func()
def _get_simple_function_class(function_interface):

View file

@ -8,6 +8,14 @@ class Record:
def __repr__(self):
    """Return a debug representation of the form ``Record(<value>)``."""
    # str.format only fills "{}" fields. The former "%s" placeholder was
    # returned literally, so the value never showed up in the output.
    return "Record({})".format(self.value)
def __eq__(self, other):
    """Compare by stream and value; objects of another type are unequal."""
    if type(other) is not type(self):
        return False
    mine = (self.stream, self.value)
    theirs = (other.stream, other.value)
    return mine == theirs
def __hash__(self):
    """Hash the same (stream, value) pair that `__eq__` compares."""
    identity = (self.stream, self.value)
    return hash(identity)
class KeyRecord(Record):
"""Data record in a keyed data stream"""
@ -15,3 +23,12 @@ class KeyRecord(Record):
def __init__(self, key, value):
    """Create a keyed record.

    Args:
        key: Key stored on this record as `self.key`.
        value: Payload forwarded to the base-class initializer.
    """
    super().__init__(value)
    self.key = key
def __eq__(self, other):
    """Compare by stream, key and value; other types are unequal."""
    if type(other) is not type(self):
        return False
    mine = (self.stream, self.key, self.value)
    theirs = (other.stream, other.key, other.value)
    return mine == theirs
def __hash__(self):
    """Hash the same (stream, key, value) triple that `__eq__` compares."""
    identity = (self.stream, self.key, self.value)
    return hash(identity)

View file

@ -1,4 +1,5 @@
import importlib
import inspect
from abc import ABC, abstractmethod
from ray import cloudpickle
@ -96,22 +97,22 @@ def load_partition(descriptor_partition_bytes: bytes):
Returns:
partition function
"""
partition_bytes, module_name, class_name, function_name =\
assert len(descriptor_partition_bytes) > 0
partition_bytes, module_name, function_name =\
gateway_client.deserialize(descriptor_partition_bytes)
if partition_bytes:
return deserialize(partition_bytes)
else:
assert module_name
mod = importlib.import_module(module_name)
# If class_name is not None, user partition is a sub class
# of Partition.
# If function_name is not None, user partition is a simple python
assert function_name
func = getattr(mod, function_name)
# If func is a python function, user partition is a simple python
# function, which will be wrapped as a SimplePartition.
if class_name:
assert function_name is None
cls = getattr(mod, class_name)
return cls()
else:
assert function_name
func = getattr(mod, function_name)
# If func is a python class, user partition is a sub class
# of Partition.
if inspect.isfunction(func):
return SimplePartition(func)
else:
assert issubclass(func, Partition)
return func()

View file

@ -55,6 +55,11 @@ class GatewayClient:
call = self._python_gateway_actor.callMethod.remote(java_params)
return deserialize(ray.get(call))
def new_instance(self, java_class_name):
    """Invoke `newInstance` on the python gateway actor.

    Args:
        java_class_name: Class name sent, serialized, to the gateway actor.

    Returns:
        The deserialized reply of the actor call.
        NOTE(review): presumably a handle to the created java object;
        confirm against the gateway implementation.
    """
    remote_new_instance = self._python_gateway_actor.newInstance.remote
    pending = remote_new_instance(serialize(java_class_name))
    reply = ray.get(pending)
    return deserialize(reply)
def serialize(obj) -> bytes:
"""Serialize a python object which can be deserialized by `PythonGateway`

View file

@ -53,7 +53,9 @@ class ExecutionEdge:
self.src_node_id = edge_pb.src_node_id
self.target_node_id = edge_pb.target_node_id
partition_bytes = edge_pb.partition
if language == Language.PYTHON:
# Sink node doesn't have partition function,
# so we only deserialize partition_bytes when it's not None or empty
if language == Language.PYTHON and partition_bytes:
self.partition = partition.load_partition(partition_bytes)

View file

@ -0,0 +1,57 @@
from abc import ABC, abstractmethod
import pickle
import msgpack
from ray.streaming import message
# First msgpack field written by CrossLangSerializer; tells a Record payload
# apart from a KeyRecord payload when deserializing.
_RECORD_TYPE_ID = 0
_KEY_RECORD_TYPE_ID = 1
# One-byte type ids. The input stream task compares the first byte of a
# message body with _PYTHON_TYPE_ID to choose between the pickle-based and
# the cross-language serializer.
# NOTE(review): the writers that prepend these bytes are not shown here.
_CROSS_LANG_TYPE_ID = b"0"
_JAVA_TYPE_ID = b"1"
_PYTHON_TYPE_ID = b"2"
class Serializer(ABC):
    """Abstract interface implemented by every stream element serializer."""

    @abstractmethod
    def serialize(self, obj):
        """Turn `obj` into its serialized form; subclasses must override."""

    @abstractmethod
    def deserialize(self, serialized_bytes):
        """Rebuild an object from `serialized_bytes`; subclasses override."""
class PythonSerializer(Serializer):
    """Serializer backed by the standard `pickle` module."""

    def serialize(self, obj):
        """Return `obj` pickled to bytes."""
        return pickle.dumps(obj)

    def deserialize(self, serialized_bytes):
        """Unpickle `serialized_bytes` and return the resulting object."""
        # NOTE(review): pickle.loads can execute arbitrary code while
        # loading; only bytes from trusted peers should reach this method.
        return pickle.loads(serialized_bytes)
class CrossLangSerializer(Serializer):
    """Serialize stream element between java/python"""

    def serialize(self, obj):
        """Pack a `Record` or `KeyRecord` into msgpack bytes.

        The packed list starts with a type id, followed by the stream and
        then either the value (Record) or the key and value (KeyRecord).

        Raises:
            Exception: If `obj` is not exactly a Record or a KeyRecord.
        """
        obj_type = type(obj)
        if obj_type is message.Record:
            fields = [_RECORD_TYPE_ID, obj.stream, obj.value]
        elif obj_type is message.KeyRecord:
            fields = [_KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value]
        else:
            raise Exception("Unsupported value {}".format(obj))
        return msgpack.packb(fields, use_bin_type=True)

    def deserialize(self, data):
        """Rebuild the `Record` or `KeyRecord` packed by `serialize`.

        Raises:
            Exception: If the leading type id is not a known record type.
        """
        fields = msgpack.unpackb(data, raw=False)
        type_id = fields[0]
        if type_id == _RECORD_TYPE_ID:
            stream, value = fields[1:]
            record = message.Record(value)
            record.stream = stream
            return record
        if type_id == _KEY_RECORD_TYPE_ID:
            stream, key, value = fields[1:]
            key_record = message.KeyRecord(key, value)
            key_record.stream = stream
            return key_record
        raise Exception("Unsupported type id {}, type {}".format(
            type_id, type(type_id)))

View file

@ -1,11 +1,13 @@
import logging
import pickle
import threading
from abc import ABC, abstractmethod
from ray.streaming.collector import OutputCollector
from ray.streaming.config import Config
from ray.streaming.context import RuntimeContextImpl
from ray.streaming.runtime import serialization
from ray.streaming.runtime.serialization import \
PythonSerializer, CrossLangSerializer
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader
logger = logging.getLogger(__name__)
@ -38,36 +40,40 @@ class StreamTask(ABC):
# writers
collectors = []
for edge in execution_node.output_edges:
output_actor_ids = {}
output_actors_map = {}
task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
edge.target_node_id)
for target_task_id, target_actor in task_id2_worker.items():
channel_name = ChannelID.gen_id(self.task_id, target_task_id,
execution_graph.build_time())
output_actor_ids[channel_name] = target_actor
if len(output_actor_ids) > 0:
channel_ids = list(output_actor_ids.keys())
to_actor_ids = list(output_actor_ids.values())
writer = DataWriter(channel_ids, to_actor_ids, channel_conf)
logger.info("Create DataWriter succeed.")
output_actors_map[channel_name] = target_actor
if len(output_actors_map) > 0:
channel_ids = list(output_actors_map.keys())
target_actors = list(output_actors_map.values())
logger.info(
"Create DataWriter channel_ids {}, target_actors {}."
.format(channel_ids, target_actors))
writer = DataWriter(channel_ids, target_actors, channel_conf)
self.writers[edge] = writer
collectors.append(
OutputCollector(channel_ids, writer, edge.partition))
OutputCollector(writer, channel_ids, target_actors,
edge.partition))
# readers
input_actor_ids = {}
input_actor_map = {}
for edge in execution_node.input_edges:
task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
edge.src_node_id)
for src_task_id, src_actor in task_id2_worker.items():
channel_name = ChannelID.gen_id(src_task_id, self.task_id,
execution_graph.build_time())
input_actor_ids[channel_name] = src_actor
if len(input_actor_ids) > 0:
channel_ids = list(input_actor_ids.keys())
from_actor_ids = list(input_actor_ids.values())
logger.info("Create DataReader, channels {}.".format(channel_ids))
self.reader = DataReader(channel_ids, from_actor_ids, channel_conf)
input_actor_map[channel_name] = src_actor
if len(input_actor_map) > 0:
channel_ids = list(input_actor_map.keys())
from_actors = list(input_actor_map.values())
logger.info("Create DataReader, channels {}, input_actors {}."
.format(channel_ids, from_actors))
self.reader = DataReader(channel_ids, from_actors, channel_conf)
def exit_handler():
# Make DataReader stop read data when MockQueue destructor
@ -111,6 +117,8 @@ class InputStreamTask(StreamTask):
self.read_timeout_millis = \
int(worker.config.get(Config.READ_TIMEOUT_MS,
Config.DEFAULT_READ_TIMEOUT_MS))
self.python_serializer = PythonSerializer()
self.cross_lang_serializer = CrossLangSerializer()
def init(self):
pass
@ -120,7 +128,11 @@ class InputStreamTask(StreamTask):
item = self.reader.read(self.read_timeout_millis)
if item is not None:
msg_data = item.body()
msg = pickle.loads(msg_data)
type_id = msg_data[:1]
if (type_id == serialization._PYTHON_TYPE_ID):
msg = self.python_serializer.deserialize(msg_data[1:])
else:
msg = self.cross_lang_serializer.deserialize(msg_data[1:])
self.processor.process(msg)
self.stopped = True

View file

@ -147,13 +147,17 @@ class ChannelCreationParametersBuilder:
wrap initial parameters needed by a streaming queue
"""
_java_reader_async_function_descriptor = JavaFunctionDescriptor(
"io.ray.streaming.runtime.worker", "onReaderMessage", "([B)V")
"io.ray.streaming.runtime.worker.JobWorker", "onReaderMessage",
"([B)V")
_java_reader_sync_function_descriptor = JavaFunctionDescriptor(
"io.ray.streaming.runtime.worker", "onReaderMessageSync", "([B)[B")
"io.ray.streaming.runtime.worker.JobWorker", "onReaderMessageSync",
"([B)[B")
_java_writer_async_function_descriptor = JavaFunctionDescriptor(
"io.ray.streaming.runtime.worker", "onWriterMessage", "([B)V")
"io.ray.streaming.runtime.worker.JobWorker", "onWriterMessage",
"([B)V")
_java_writer_sync_function_descriptor = JavaFunctionDescriptor(
"io.ray.streaming.runtime.worker", "onWriterMessageSync", "([B)[B")
"io.ray.streaming.runtime.worker.JobWorker", "onWriterMessageSync",
"([B)[B")
_python_reader_async_function_descriptor = PythonFunctionDescriptor(
"ray.streaming.runtime.worker", "on_reader_message", "JobWorker")
_python_reader_sync_function_descriptor = PythonFunctionDescriptor(

View file

@ -10,6 +10,9 @@ from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask
logger = logging.getLogger(__name__)
# Special flag indicating that this actor is not ready yet.
_NOT_READY_FLAG_ = b" " * 4
@ray.remote
class JobWorker(object):
@ -66,23 +69,31 @@ class JobWorker(object):
type(self.stream_processor))
def on_reader_message(self, buffer: bytes):
"""used in direct call mode"""
"""Called by upstream queue writer to send data message to downstream
queue reader.
"""
self.reader_client.on_reader_message(buffer)
def on_reader_message_sync(self, buffer: bytes):
"""used in direct call mode"""
"""Called by upstream queue writer to send control message to downstream
queue reader.
"""
if self.reader_client is None:
return b" " * 4 # special flag to indicate this actor not ready
return _NOT_READY_FLAG_
result = self.reader_client.on_reader_message_sync(buffer)
return result.to_pybytes()
def on_writer_message(self, buffer: bytes):
"""used in direct call mode"""
"""Called by downstream queue reader to send notify message to
upstream queue writer.
"""
self.writer_client.on_writer_message(buffer)
def on_writer_message_sync(self, buffer: bytes):
"""used in direct call mode"""
"""Called by downstream queue reader to send control message to
upstream queue writer.
"""
if self.writer_client is None:
return b" " * 4 # special flag to indicate this actor not ready
return _NOT_READY_FLAG_
result = self.writer_client.on_writer_message_sync(buffer)
return result.to_pybytes()

View file

@ -14,9 +14,9 @@ class MapFunc(function.MapFunction):
def test_load_function():
# function_bytes, module_name, class_name, function_name,
# function_bytes, module_name, function_name/class_name,
# function_interface
descriptor_func_bytes = gateway_client.serialize(
[None, __name__, MapFunc.__name__, None, "MapFunction"])
[None, __name__, MapFunc.__name__, "MapFunction"])
func = function.load_function(descriptor_func_bytes)
assert type(func) is MapFunc

View file

@ -0,0 +1,70 @@
import json
import ray
from ray.streaming import StreamingContext
import subprocess
import os
def map_func1(x):
    """Log the element and return its string form."""
    print("HybridStreamTest map_func1", x)
    as_text = str(x)
    return as_text
def filter_func1(x):
    """Log the element and keep it only when it does not contain "b"."""
    print("HybridStreamTest filter_func1", x)
    keep = "b" not in x
    return keep
def sink_func1(x):
    """Print the value that reached the sink; returns nothing."""
    print("HybridStreamTest sink_func1 value:", x)
def test_hybrid_stream():
    """Run a python -> java -> python pipeline and check the sink output.

    Builds the java streaming test jar with bazel, starts Ray with that jar
    on the java worker classpath, submits a job that maps and filters the
    values in java and sinks them to a file from python, then checks which
    values reached the file.
    """
    import time

    subprocess.check_call(
        ["bazel", "build", "//streaming/java:all_streaming_tests_deploy.jar"])
    current_dir = os.path.abspath(os.path.dirname(__file__))
    jar_path = os.path.join(
        current_dir,
        "../../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar")
    jar_path = os.path.abspath(jar_path)
    print("jar_path", jar_path)
    java_worker_options = json.dumps(["-classpath", jar_path])
    print("java_worker_options", java_worker_options)
    assert not ray.is_initialized()
    ray.init(
        load_code_from_local=True,
        include_java=True,
        java_worker_options=java_worker_options,
        _internal_config=json.dumps({
            "num_workers_per_process_java": 1
        }))
    # Shut Ray down even when building or submitting the job fails, so a
    # failure here does not leave Ray initialized for the tests that follow.
    try:
        sink_file = "/tmp/ray_streaming_test_hybrid_stream.txt"
        if os.path.exists(sink_file):
            os.remove(sink_file)

        def sink_func(x):
            print("HybridStreamTest", x)
            with open(sink_file, "a") as f:
                f.write(str(x))

        ctx = StreamingContext.Builder().build()
        ctx.from_values("a", "b", "c") \
            .as_java_stream() \
            .map("io.ray.streaming.runtime.demo.HybridStreamTest$Mapper1") \
            .filter("io.ray.streaming.runtime.demo.HybridStreamTest$Filter1") \
            .as_python_stream() \
            .sink(sink_func)
        ctx.submit("HybridStreamTest")
        # Give the submitted job time to process the three values.
        time.sleep(3)
    finally:
        ray.shutdown()
    with open(sink_file, "r") as f:
        result = f.read()
    assert "a" in result
    assert "b" not in result
    assert "c" in result
if __name__ == "__main__":
test_hybrid_stream()

View file

@ -0,0 +1,13 @@
from ray.streaming.runtime.serialization import CrossLangSerializer
from ray.streaming.message import Record, KeyRecord
def test_serialize():
    """Round-trip a Record and a KeyRecord through CrossLangSerializer."""
    serializer = CrossLangSerializer()

    record = Record("value")
    record.stream = "stream1"
    key_record = KeyRecord("key", "value")
    key_record.stream = "stream2"

    restored_record = serializer.deserialize(serializer.serialize(record))
    assert record == restored_record
    restored_key_record = serializer.deserialize(
        serializer.serialize(key_record))
    assert key_record == restored_key_record

View file

@ -0,0 +1,31 @@
import ray
from ray.streaming import StreamingContext
def test_data_stream():
    """Check id and parallelism are shared across language views.

    Converts a python stream to a java stream and back, then verifies all
    three views report the same id and the same parallelism after it is
    changed through one of them.
    """
    ray.init(load_code_from_local=True, include_java=True)
    # Shut Ray down even when an assertion fails, so the next test can call
    # ray.init() again.
    try:
        ctx = StreamingContext.Builder().build()
        stream = ctx.from_values(1, 2, 3)
        java_stream = stream.as_java_stream()
        python_stream = java_stream.as_python_stream()
        assert stream.get_id() == java_stream.get_id()
        assert stream.get_id() == python_stream.get_id()
        python_stream.set_parallelism(10)
        assert stream.get_parallelism() == java_stream.get_parallelism()
        assert stream.get_parallelism() == python_stream.get_parallelism()
    finally:
        ray.shutdown()
def test_key_data_stream():
    """Check id and parallelism are shared across keyed language views.

    Converts a keyed python stream to a java stream and back, then verifies
    all three views report the same id and the same parallelism after it is
    changed through one of them.
    """
    ray.init(load_code_from_local=True, include_java=True)
    # Shut Ray down even when an assertion fails, so the next test can call
    # ray.init() again.
    try:
        ctx = StreamingContext.Builder().build()
        key_stream = ctx.from_values(
            "a", "b", "c").map(lambda x: (x, 1)).key_by(lambda x: x[0])
        java_stream = key_stream.as_java_stream()
        python_stream = java_stream.as_python_stream()
        assert key_stream.get_id() == java_stream.get_id()
        assert key_stream.get_id() == python_stream.get_id()
        python_stream.set_parallelism(10)
        assert key_stream.get_parallelism() == java_stream.get_parallelism()
        assert key_stream.get_parallelism() == python_stream.get_parallelism()
    finally:
        ray.shutdown()

View file

@ -32,7 +32,9 @@ def test_simple_word_count():
def sink_func(x):
with open(sink_file, "a") as f:
f.write("{}:{},".format(x[0], x[1]))
line = "{}:{},".format(x[0], x[1])
print("sink_func", line)
f.write(line)
ctx.from_values("a", "b", "c") \
.set_parallelism(1) \

View file

@ -26,6 +26,13 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(
return reinterpret_cast<jlong>(reader_client);
}
// JNI entry point for TransferHandler.handleWriterMessageNative.
// Treats `ptr` as a WriterClient pointer and passes it the Java byte array,
// converted with JByteArrayToBuffer, through OnWriterMessage. Nothing is
// returned to Java.
// NOTE(review): `ptr` is presumably the handle produced by the matching
// writer-client creation call; it is not validated here -- confirm callers
// never pass 0.
JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative(
JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
auto *writer_client = reinterpret_cast<WriterClient *>(ptr);
writer_client->OnWriterMessage(JByteArrayToBuffer(env, bytes));
}
JNIEXPORT jbyteArray JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative(
JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {