[Streaming] Streaming Cross-Lang API (#7464)

This commit is contained in:
chaokunyang 2020-04-29 13:42:08 +08:00 committed by GitHub
parent 101255f782
commit 91f630f709
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
72 changed files with 1612 additions and 408 deletions

View file

@ -542,6 +542,7 @@ def init(address=None,
raylet_socket_name=None, raylet_socket_name=None,
temp_dir=None, temp_dir=None,
load_code_from_local=False, load_code_from_local=False,
java_worker_options=None,
use_pickle=True, use_pickle=True,
_internal_config=None, _internal_config=None,
lru_evict=False): lru_evict=False):
@ -651,6 +652,7 @@ def init(address=None,
conventional location, e.g., "/tmp/ray". conventional location, e.g., "/tmp/ray".
load_code_from_local: Whether code should be loaded from a local load_code_from_local: Whether code should be loaded from a local
module or from the GCS. module or from the GCS.
java_worker_options: Overwrite the options to start Java workers.
use_pickle: Deprecated. use_pickle: Deprecated.
_internal_config (str): JSON configuration for overriding _internal_config (str): JSON configuration for overriding
RayConfig defaults. For testing purposes ONLY. RayConfig defaults. For testing purposes ONLY.
@ -758,6 +760,7 @@ def init(address=None,
raylet_socket_name=raylet_socket_name, raylet_socket_name=raylet_socket_name,
temp_dir=temp_dir, temp_dir=temp_dir,
load_code_from_local=load_code_from_local, load_code_from_local=load_code_from_local,
java_worker_options=java_worker_options,
_internal_config=_internal_config, _internal_config=_internal_config,
) )
# Start the Ray processes. We set shutdown_at_exit=False because we # Start the Ray processes. We set shutdown_at_exit=False because we
@ -808,6 +811,9 @@ def init(address=None,
if raylet_socket_name is not None: if raylet_socket_name is not None:
raise ValueError("When connecting to an existing cluster, " raise ValueError("When connecting to an existing cluster, "
"raylet_socket_name must not be provided.") "raylet_socket_name must not be provided.")
if java_worker_options is not None:
raise ValueError("When connecting to an existing cluster, "
"java_worker_options must not be provided.")
if _internal_config is not None and len(_internal_config) != 0: if _internal_config is not None and len(_internal_config) != 0:
raise ValueError("When connecting to an existing cluster, " raise ValueError("When connecting to an existing cluster, "
"_internal_config must not be provided.") "_internal_config must not be provided.")

View file

@ -39,6 +39,7 @@ define_java_module(
":io_ray_ray_streaming-state", ":io_ray_ray_streaming-state",
":io_ray_ray_streaming-api", ":io_ray_ray_streaming-api",
"@ray_streaming_maven//:com_google_guava_guava", "@ray_streaming_maven//:com_google_guava_guava",
"@ray_streaming_maven//:org_apache_commons_commons_lang3",
"@ray_streaming_maven//:org_slf4j_slf4j_api", "@ray_streaming_maven//:org_slf4j_slf4j_api",
"@ray_streaming_maven//:org_slf4j_slf4j_log4j12", "@ray_streaming_maven//:org_slf4j_slf4j_log4j12",
"@ray_streaming_maven//:org_testng_testng", "@ray_streaming_maven//:org_testng_testng",
@ -46,7 +47,12 @@ define_java_module(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":io_ray_ray_streaming-state", ":io_ray_ray_streaming-state",
"//java:io_ray_ray_api",
"//java:io_ray_ray_runtime",
"@ray_streaming_maven//:com_google_code_findbugs_jsr305",
"@ray_streaming_maven//:com_google_code_gson_gson",
"@ray_streaming_maven//:com_google_guava_guava", "@ray_streaming_maven//:com_google_guava_guava",
"@ray_streaming_maven//:org_apache_commons_commons_lang3",
"@ray_streaming_maven//:org_slf4j_slf4j_api", "@ray_streaming_maven//:org_slf4j_slf4j_api",
"@ray_streaming_maven//:org_slf4j_slf4j_log4j12", "@ray_streaming_maven//:org_slf4j_slf4j_log4j12",
], ],
@ -129,8 +135,9 @@ define_java_module(
":io_ray_ray_streaming-api", ":io_ray_ray_streaming-api",
":io_ray_ray_streaming-runtime", ":io_ray_ray_streaming-runtime",
"@ray_streaming_maven//:com_google_guava_guava", "@ray_streaming_maven//:com_google_guava_guava",
"@ray_streaming_maven//:com_google_code_findbugs_jsr305",
"@ray_streaming_maven//:org_apache_commons_commons_lang3",
"@ray_streaming_maven//:de_ruedigermoeller_fst", "@ray_streaming_maven//:de_ruedigermoeller_fst",
"@ray_streaming_maven//:org_msgpack_msgpack_core",
"@ray_streaming_maven//:org_aeonbits_owner_owner", "@ray_streaming_maven//:org_aeonbits_owner_owner",
"@ray_streaming_maven//:org_slf4j_slf4j_api", "@ray_streaming_maven//:org_slf4j_slf4j_api",
"@ray_streaming_maven//:org_slf4j_slf4j_log4j12", "@ray_streaming_maven//:org_slf4j_slf4j_log4j12",
@ -146,10 +153,12 @@ define_java_module(
"//java:io_ray_ray_api", "//java:io_ray_ray_api",
"//java:io_ray_ray_runtime", "//java:io_ray_ray_runtime",
"@ray_streaming_maven//:com_github_davidmoten_flatbuffers_java", "@ray_streaming_maven//:com_github_davidmoten_flatbuffers_java",
"@ray_streaming_maven//:com_google_code_findbugs_jsr305",
"@ray_streaming_maven//:com_google_guava_guava", "@ray_streaming_maven//:com_google_guava_guava",
"@ray_streaming_maven//:com_google_protobuf_protobuf_java", "@ray_streaming_maven//:com_google_protobuf_protobuf_java",
"@ray_streaming_maven//:de_ruedigermoeller_fst", "@ray_streaming_maven//:de_ruedigermoeller_fst",
"@ray_streaming_maven//:org_aeonbits_owner_owner", "@ray_streaming_maven//:org_aeonbits_owner_owner",
"@ray_streaming_maven//:org_apache_commons_commons_lang3",
"@ray_streaming_maven//:org_msgpack_msgpack_core", "@ray_streaming_maven//:org_msgpack_msgpack_core",
"@ray_streaming_maven//:org_slf4j_slf4j_api", "@ray_streaming_maven//:org_slf4j_slf4j_api",
"@ray_streaming_maven//:org_slf4j_slf4j_log4j12", "@ray_streaming_maven//:org_slf4j_slf4j_log4j12",

View file

@ -6,8 +6,11 @@ def gen_streaming_java_deps():
artifacts = [ artifacts = [
"com.beust:jcommander:1.72", "com.beust:jcommander:1.72",
"com.google.guava:guava:27.0.1-jre", "com.google.guava:guava:27.0.1-jre",
"com.google.code.findbugs:jsr305:3.0.2",
"com.google.code.gson:gson:2.8.5",
"com.github.davidmoten:flatbuffers-java:1.9.0.1", "com.github.davidmoten:flatbuffers-java:1.9.0.1",
"com.google.protobuf:protobuf-java:3.8.0", "com.google.protobuf:protobuf-java:3.8.0",
"org.apache.commons:commons-lang3:3.4",
"de.ruedigermoeller:fst:2.57", "de.ruedigermoeller:fst:2.57",
"org.aeonbits.owner:owner:1.0.10", "org.aeonbits.owner:owner:1.0.10",
"org.slf4j:slf4j-api:1.7.12", "org.slf4j:slf4j-api:1.7.12",
@ -22,7 +25,6 @@ def gen_streaming_java_deps():
"org.mockito:mockito-all:1.10.19", "org.mockito:mockito-all:1.10.19",
"org.powermock:powermock-module-testng:1.6.6", "org.powermock:powermock-module-testng:1.6.6",
"org.powermock:powermock-api-mockito:1.6.6", "org.powermock:powermock-api-mockito:1.6.6",
"org.projectlombok:lombok:1.16.20",
], ],
repositories = [ repositories = [
"https://repo1.maven.org/maven2/", "https://repo1.maven.org/maven2/",

View file

@ -22,16 +22,36 @@
<artifactId>ray-api</artifactId> <artifactId>ray-api</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency>
<groupId>io.ray</groupId>
<artifactId>ray-runtime</artifactId>
<version>${project.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.ray</groupId> <groupId>org.ray</groupId>
<artifactId>streaming-state</artifactId> <artifactId>streaming-state</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
<version>3.0.2</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.8.5</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId> <groupId>com.google.guava</groupId>
<artifactId>guava</artifactId> <artifactId>guava</artifactId>
<version>27.0.1-jre</version> <version>27.0.1-jre</version>
</dependency> </dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
<dependency> <dependency>
<groupId>org.slf4j</groupId> <groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId> <artifactId>slf4j-api</artifactId>

View file

@ -22,6 +22,11 @@
<artifactId>ray-api</artifactId> <artifactId>ray-api</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency>
<groupId>io.ray</groupId>
<artifactId>ray-runtime</artifactId>
<version>${project.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.ray</groupId> <groupId>org.ray</groupId>
<artifactId>streaming-state</artifactId> <artifactId>streaming-state</artifactId>

View file

@ -0,0 +1,129 @@
package io.ray.streaming.api.context;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.gson.Gson;
import io.ray.api.Ray;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.util.NetworkUtil;
import java.io.File;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Starts and stops the Ray runtime that a streaming job runs on.
 *
 * <p>Depending on the flags it either initializes Ray in-process, or (for cross-language
 * jobs) launches a head node through the {@code ray start} command line and then connects
 * this process to it.
 */
class ClusterStarter {
  private static final Logger LOG = LoggerFactory.getLogger(ClusterStarter.class);
  private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/plasma_store_socket";
  private static final String RAYLET_SOCKET_NAME = "/tmp/ray/raylet_socket";

  /**
   * Initialize Ray for the current process. Must only be called while Ray is not initialized.
   *
   * @param isCrossLanguage Whether the job is cross-language. If so, a head node is started
   *     via the {@code ray start} command before connecting to it.
   * @param isLocal Whether to use run mode {@code SINGLE_PROCESS} instead of {@code CLUSTER}.
   * @throws IllegalArgumentException If Ray is already initialized.
   * @throws RuntimeException If the {@code ray start} command fails.
   */
  static synchronized void startCluster(boolean isCrossLanguage, boolean isLocal) {
    Preconditions.checkArgument(Ray.internal() == null);
    RayConfig.reset();
    if (!isLocal) {
      System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
      System.setProperty("ray.run-mode", "CLUSTER");
    } else {
      System.clearProperty("ray.raylet.config.num_workers_per_process_java");
      System.setProperty("ray.run-mode", "SINGLE_PROCESS");
    }

    if (!isCrossLanguage) {
      Ray.init();
      return;
    }

    // Delete existing socket files.
    for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
      File file = new File(socket);
      if (file.exists()) {
        LOG.info("Delete existing socket file {}", file);
        // Best effort: report a failed deletion instead of ignoring it silently, so that a
        // subsequent `ray start` failure can be traced back to the stale socket file.
        if (!file.delete()) {
          LOG.warn("Failed to delete existing socket file {}", file);
        }
      }
    }

    String nodeManagerPort = String.valueOf(NetworkUtil.getUnusedPort());

    // Jars in the `ray` wheel don't contain test classes, so we add test classes explicitly.
    // Since mvn test classes contain `test` in their path and bazel test classes are located
    // in a jar with `test` included in its name, we can filter classpath entries by `test`
    // to pick up the test classes.
    String classpath = Stream.of(System.getProperty("java.class.path").split(":"))
        .filter(s -> !s.contains(" ") && s.contains("test"))
        .collect(Collectors.joining(":"));
    Gson gson = new Gson();
    String workerOptions = gson.toJson(ImmutableList.of("-classpath", classpath));
    Map<String, String> config = new HashMap<>(RayConfig.create().rayletConfigParameters);
    config.put("num_workers_per_process_java", "1");

    // Start ray cluster.
    List<String> startCommand = ImmutableList.of(
        "ray",
        "start",
        "--head",
        "--redis-port=6379",
        String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
        String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
        String.format("--node-manager-port=%s", nodeManagerPort),
        "--load-code-from-local",
        "--include-java",
        "--java-worker-options=" + workerOptions,
        "--internal-config=" + gson.toJson(config)
    );
    if (!executeCommand(startCommand, 10)) {
      throw new RuntimeException("Couldn't start ray cluster.");
    }

    // Connect to the cluster.
    System.setProperty("ray.redis.address", "127.0.0.1:6379");
    System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
    System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
    System.setProperty("ray.raylet.node-manager-port", nodeManagerPort);
    Ray.init();
  }

  /**
   * Shut down Ray in the current process and clear the system properties set by
   * {@link #startCluster(boolean, boolean)}.
   *
   * @param isCrossLanguage Whether the job is cross-language. If so, the {@code ray stop}
   *     command is executed as well.
   * @throws RuntimeException If the {@code ray stop} command fails.
   */
  public static synchronized void stopCluster(boolean isCrossLanguage) {
    // Disconnect from the cluster.
    Ray.shutdown();
    System.clearProperty("ray.redis.address");
    System.clearProperty("ray.object-store.socket-name");
    System.clearProperty("ray.raylet.socket-name");
    System.clearProperty("ray.raylet.node-manager-port");
    System.clearProperty("ray.raylet.config.num_workers_per_process_java");
    System.clearProperty("ray.run-mode");

    if (isCrossLanguage) {
      // Stop ray cluster.
      final List<String> stopCommand = ImmutableList.of(
          "ray",
          "stop"
      );
      if (!executeCommand(stopCommand, 10)) {
        throw new RuntimeException("Couldn't stop ray cluster");
      }
    }
  }

  /**
   * Execute an external command, inheriting this process' stdout and stderr.
   *
   * @param command The command and its arguments.
   * @param waitTimeoutSeconds Maximum time to wait for the command to exit.
   * @return Whether the command succeeded, i.e. exited with code 0 within the timeout.
   *     A command that times out is killed and reported as failed.
   * @throws RuntimeException If the command can't be started or the wait is interrupted.
   */
  private static boolean executeCommand(List<String> command, int waitTimeoutSeconds) {
    LOG.info("Executing command: {}", String.join(" ", command));
    try {
      ProcessBuilder processBuilder = new ProcessBuilder(command)
          .redirectOutput(ProcessBuilder.Redirect.INHERIT)
          .redirectError(ProcessBuilder.Redirect.INHERIT);
      Process process = processBuilder.start();
      boolean exit = process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS);
      if (!exit) {
        // `destroyForcibly` doesn't wait for termination, so calling `exitValue` right after
        // it may throw IllegalThreadStateException. A timed out command has failed anyway.
        process.destroyForcibly();
        LOG.error("Command didn't exit within {} seconds and was killed: {}",
            waitTimeoutSeconds, String.join(" ", command));
        return false;
      }
      return process.exitValue() == 0;
    } catch (InterruptedException e) {
      // Restore the interrupt status so that callers can still observe the interruption.
      Thread.currentThread().interrupt();
      throw new RuntimeException("Error executing command " + String.join(" ", command), e);
    } catch (Exception e) {
      throw new RuntimeException("Error executing command " + String.join(" ", command), e);
    }
  }
}

View file

@ -1,10 +1,12 @@
package io.ray.streaming.api.context; package io.ray.streaming.api.context;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.ray.api.Ray;
import io.ray.streaming.api.stream.StreamSink; import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.jobgraph.JobGraph; import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.jobgraph.JobGraphBuilder; import io.ray.streaming.jobgraph.JobGraphBuilder;
import io.ray.streaming.schedule.JobScheduler; import io.ray.streaming.schedule.JobScheduler;
import io.ray.streaming.util.Config;
import java.io.Serializable; import java.io.Serializable;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
@ -13,11 +15,14 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.ServiceLoader; import java.util.ServiceLoader;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** /**
* Encapsulate the context information of a streaming Job. * Encapsulate the context information of a streaming Job.
*/ */
public class StreamingContext implements Serializable { public class StreamingContext implements Serializable {
private static final Logger LOG = LoggerFactory.getLogger(StreamingContext.class);
private transient AtomicInteger idGenerator; private transient AtomicInteger idGenerator;
@ -54,6 +59,20 @@ public class StreamingContext implements Serializable {
this.jobGraph = jobGraphBuilder.build(); this.jobGraph = jobGraphBuilder.build();
jobGraph.printJobGraph(); jobGraph.printJobGraph();
if (Ray.internal() == null) {
if (Config.MEMORY_CHANNEL.equalsIgnoreCase(jobConfig.get(Config.CHANNEL_TYPE))) {
Preconditions.checkArgument(!jobGraph.isCrossLanguageGraph());
ClusterStarter.startCluster(false, true);
LOG.info("Created local cluster for job {}.", jobName);
} else {
ClusterStarter.startCluster(jobGraph.isCrossLanguageGraph(), false);
LOG.info("Created multi process cluster for job {}.", jobName);
}
Runtime.getRuntime().addShutdownHook(new Thread(StreamingContext.this::stop));
} else {
LOG.info("Reuse existing cluster.");
}
ServiceLoader<JobScheduler> serviceLoader = ServiceLoader.load(JobScheduler.class); ServiceLoader<JobScheduler> serviceLoader = ServiceLoader.load(JobScheduler.class);
Iterator<JobScheduler> iterator = serviceLoader.iterator(); Iterator<JobScheduler> iterator = serviceLoader.iterator();
Preconditions.checkArgument(iterator.hasNext(), Preconditions.checkArgument(iterator.hasNext(),
@ -77,4 +96,10 @@ public class StreamingContext implements Serializable {
public void withConfig(Map<String, String> jobConfig) { public void withConfig(Map<String, String> jobConfig) {
this.jobConfig = jobConfig; this.jobConfig = jobConfig;
} }
public void stop() {
if (Ray.internal() != null) {
ClusterStarter.stopCluster(jobGraph.isCrossLanguageGraph());
}
}
} }

View file

@ -1,6 +1,7 @@
package io.ray.streaming.api.stream; package io.ray.streaming.api.stream;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.StreamingContext; import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.FilterFunction; import io.ray.streaming.api.function.impl.FilterFunction;
import io.ray.streaming.api.function.impl.FlatMapFunction; import io.ray.streaming.api.function.impl.FlatMapFunction;
@ -15,24 +16,44 @@ import io.ray.streaming.operator.impl.FlatMapOperator;
import io.ray.streaming.operator.impl.KeyByOperator; import io.ray.streaming.operator.impl.KeyByOperator;
import io.ray.streaming.operator.impl.MapOperator; import io.ray.streaming.operator.impl.MapOperator;
import io.ray.streaming.operator.impl.SinkOperator; import io.ray.streaming.operator.impl.SinkOperator;
import io.ray.streaming.python.stream.PythonDataStream;
/** /**
* Represents a stream of data. * Represents a stream of data.
* * <p>This class defines all the streaming operations.
* This class defines all the streaming operations.
* *
* @param <T> Type of data in the stream. * @param <T> Type of data in the stream.
*/ */
public class DataStream<T> extends Stream<T> { public class DataStream<T> extends Stream<DataStream<T>, T> {
public DataStream(StreamingContext streamingContext, StreamOperator streamOperator) { public DataStream(StreamingContext streamingContext, StreamOperator streamOperator) {
super(streamingContext, streamOperator); super(streamingContext, streamOperator);
} }
public DataStream(DataStream input, StreamOperator streamOperator) { public DataStream(StreamingContext streamingContext,
StreamOperator streamOperator,
Partition<T> partition) {
super(streamingContext, streamOperator, partition);
}
public <R> DataStream(DataStream<R> input, StreamOperator streamOperator) {
super(input, streamOperator); super(input, streamOperator);
} }
public <R> DataStream(DataStream<R> input,
StreamOperator streamOperator,
Partition<T> partition) {
super(input, streamOperator, partition);
}
/**
* Create a java stream that references the passed python stream.
* Changes in the new stream will be reflected in the referenced stream and vice versa.
*/
public DataStream(PythonDataStream referencedStream) {
super(referencedStream);
}
/** /**
* Apply a map function to this stream. * Apply a map function to this stream.
* *
@ -41,7 +62,7 @@ public class DataStream<T> extends Stream<T> {
* @return A new DataStream. * @return A new DataStream.
*/ */
public <R> DataStream<R> map(MapFunction<T, R> mapFunction) { public <R> DataStream<R> map(MapFunction<T, R> mapFunction) {
return new DataStream<>(this, new MapOperator(mapFunction)); return new DataStream<>(this, new MapOperator<>(mapFunction));
} }
/** /**
@ -52,11 +73,11 @@ public class DataStream<T> extends Stream<T> {
* @return A new DataStream * @return A new DataStream
*/ */
public <R> DataStream<R> flatMap(FlatMapFunction<T, R> flatMapFunction) { public <R> DataStream<R> flatMap(FlatMapFunction<T, R> flatMapFunction) {
return new DataStream(this, new FlatMapOperator(flatMapFunction)); return new DataStream<>(this, new FlatMapOperator<>(flatMapFunction));
} }
public DataStream<T> filter(FilterFunction<T> filterFunction) { public DataStream<T> filter(FilterFunction<T> filterFunction) {
return new DataStream<T>(this, new FilterOperator(filterFunction)); return new DataStream<>(this, new FilterOperator<>(filterFunction));
} }
/** /**
@ -66,7 +87,7 @@ public class DataStream<T> extends Stream<T> {
* @return A new UnionStream. * @return A new UnionStream.
*/ */
public UnionStream<T> union(DataStream<T> other) { public UnionStream<T> union(DataStream<T> other) {
return new UnionStream(this, null, other); return new UnionStream<>(this, null, other);
} }
/** /**
@ -93,7 +114,7 @@ public class DataStream<T> extends Stream<T> {
* @return A new StreamSink. * @return A new StreamSink.
*/ */
public DataStreamSink<T> sink(SinkFunction<T> sinkFunction) { public DataStreamSink<T> sink(SinkFunction<T> sinkFunction) {
return new DataStreamSink<>(this, new SinkOperator(sinkFunction)); return new DataStreamSink<>(this, new SinkOperator<>(sinkFunction));
} }
/** /**
@ -104,7 +125,8 @@ public class DataStream<T> extends Stream<T> {
* @return A new KeyDataStream. * @return A new KeyDataStream.
*/ */
public <K> KeyDataStream<K, T> keyBy(KeyFunction<T, K> keyFunction) { public <K> KeyDataStream<K, T> keyBy(KeyFunction<T, K> keyFunction) {
return new KeyDataStream<>(this, new KeyByOperator(keyFunction)); checkPartitionCall();
return new KeyDataStream<>(this, new KeyByOperator<>(keyFunction));
} }
/** /**
@ -113,8 +135,8 @@ public class DataStream<T> extends Stream<T> {
* @return This stream. * @return This stream.
*/ */
public DataStream<T> broadcast() { public DataStream<T> broadcast() {
this.partition = new BroadcastPartition<>(); checkPartitionCall();
return this; return setPartition(new BroadcastPartition<>());
} }
/** /**
@ -124,19 +146,32 @@ public class DataStream<T> extends Stream<T> {
* @return This stream. * @return This stream.
*/ */
public DataStream<T> partitionBy(Partition<T> partition) { public DataStream<T> partitionBy(Partition<T> partition) {
this.partition = partition; checkPartitionCall();
return this; return setPartition(partition);
} }
/** /**
* Set parallelism to current transformation. * If parent stream is a python stream, we can't call partition related methods
* * in the java stream.
* @param parallelism The parallelism to set.
* @return This stream.
*/ */
public DataStream<T> setParallelism(int parallelism) { private void checkPartitionCall() {
this.parallelism = parallelism; if (getInputStream() != null && getInputStream().getLanguage() == Language.PYTHON) {
return this; throw new RuntimeException("Partition related methods can't be called on a " +
"java stream if parent stream is a python stream.");
}
} }
/**
* Convert this stream to a python stream.
* The converted stream and this stream are the same logical stream and share the same stream id.
* Changes in the converted stream will be reflected in this stream and vice versa.
*/
public PythonDataStream asPythonStream() {
return new PythonDataStream(this);
}
@Override
public Language getLanguage() {
return Language.JAVA;
}
} }

View file

@ -1,5 +1,6 @@
package io.ray.streaming.api.stream; package io.ray.streaming.api.stream;
import io.ray.streaming.api.Language;
import io.ray.streaming.operator.impl.SinkOperator; import io.ray.streaming.operator.impl.SinkOperator;
/** /**
@ -9,13 +10,13 @@ import io.ray.streaming.operator.impl.SinkOperator;
*/ */
public class DataStreamSink<T> extends StreamSink<T> { public class DataStreamSink<T> extends StreamSink<T> {
public DataStreamSink(DataStream<T> input, SinkOperator sinkOperator) { public DataStreamSink(DataStream input, SinkOperator sinkOperator) {
super(input, sinkOperator); super(input, sinkOperator);
this.streamingContext.addSink(this); getStreamingContext().addSink(this);
} }
public DataStreamSink<T> setParallelism(int parallelism) { @Override
this.parallelism = parallelism; public Language getLanguage() {
return this; return Language.JAVA;
} }
} }

View file

@ -14,9 +14,13 @@ import java.util.Collection;
*/ */
public class DataStreamSource<T> extends DataStream<T> implements StreamSource<T> { public class DataStreamSource<T> extends DataStream<T> implements StreamSource<T> {
public DataStreamSource(StreamingContext streamingContext, SourceFunction<T> sourceFunction) { private DataStreamSource(StreamingContext streamingContext, SourceFunction<T> sourceFunction) {
super(streamingContext, new SourceOperator<>(sourceFunction)); super(streamingContext, new SourceOperator<>(sourceFunction), new RoundRobinPartition<>());
super.partition = new RoundRobinPartition<>(); }
public static <T> DataStreamSource<T> fromSource(
StreamingContext context, SourceFunction<T> sourceFunction) {
return new DataStreamSource<>(context, sourceFunction);
} }
/** /**
@ -27,14 +31,9 @@ public class DataStreamSource<T> extends DataStream<T> implements StreamSource<T
* @param <T> The type of source data. * @param <T> The type of source data.
* @return A DataStreamSource. * @return A DataStreamSource.
*/ */
public static <T> DataStreamSource<T> buildSource( public static <T> DataStreamSource<T> fromCollection(
StreamingContext context, Collection<T> values) { StreamingContext context, Collection<T> values) {
return new DataStreamSource(context, new CollectionSourceFunction(values)); return new DataStreamSource<>(context, new CollectionSourceFunction<>(values));
} }
@Override
public DataStreamSource<T> setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
}
} }

View file

@ -2,9 +2,12 @@ package io.ray.streaming.api.stream;
import io.ray.streaming.api.function.impl.AggregateFunction; import io.ray.streaming.api.function.impl.AggregateFunction;
import io.ray.streaming.api.function.impl.ReduceFunction; import io.ray.streaming.api.function.impl.ReduceFunction;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.partition.impl.KeyPartition; import io.ray.streaming.api.partition.impl.KeyPartition;
import io.ray.streaming.operator.StreamOperator; import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.operator.impl.ReduceOperator; import io.ray.streaming.operator.impl.ReduceOperator;
import io.ray.streaming.python.stream.PythonDataStream;
import io.ray.streaming.python.stream.PythonKeyDataStream;
/** /**
* Represents a DataStream returned by a key-by operation. * Represents a DataStream returned by a key-by operation.
@ -12,11 +15,19 @@ import io.ray.streaming.operator.impl.ReduceOperator;
* @param <K> Type of the key. * @param <K> Type of the key.
* @param <T> Type of the data. * @param <T> Type of the data.
*/ */
@SuppressWarnings("unchecked")
public class KeyDataStream<K, T> extends DataStream<T> { public class KeyDataStream<K, T> extends DataStream<T> {
public KeyDataStream(DataStream<T> input, StreamOperator streamOperator) { public KeyDataStream(DataStream<T> input, StreamOperator streamOperator) {
super(input, streamOperator); super(input, streamOperator, (Partition<T>) new KeyPartition<K, T>());
this.partition = new KeyPartition(); }
/**
* Create a java stream that references the passed python stream.
* Changes in the new stream will be reflected in the referenced stream and vice versa.
*/
public KeyDataStream(PythonDataStream referencedStream) {
super(referencedStream);
} }
/** /**
@ -41,8 +52,13 @@ public class KeyDataStream<K, T> extends DataStream<T> {
return new DataStream<>(this, null); return new DataStream<>(this, null);
} }
public KeyDataStream<K, T> setParallelism(int parallelism) { /**
this.parallelism = parallelism; * Convert this stream as a python stream.
return this; * The converted stream and this stream are the same logical stream, which has same stream id.
* Changes in converted stream will be reflected in this stream and vice versa.
*/
public PythonKeyDataStream asPythonStream() {
return new PythonKeyDataStream(this);
} }
} }

View file

@ -1,58 +1,99 @@
package io.ray.streaming.api.stream; package io.ray.streaming.api.stream;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.StreamingContext; import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.partition.Partition; import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.partition.impl.RoundRobinPartition; import io.ray.streaming.api.partition.impl.RoundRobinPartition;
import io.ray.streaming.operator.Operator;
import io.ray.streaming.operator.StreamOperator; import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.python.PythonOperator;
import io.ray.streaming.python.PythonPartition; import io.ray.streaming.python.PythonPartition;
import io.ray.streaming.python.stream.PythonStream;
import java.io.Serializable; import java.io.Serializable;
/** /**
* Abstract base class of all stream types. * Abstract base class of all stream types.
* *
* @param <S> Type of stream class
* @param <T> Type of the data in the stream. * @param <T> Type of the data in the stream.
*/ */
public abstract class Stream<T> implements Serializable { public abstract class Stream<S extends Stream<S, T>, T>
protected int id; implements Serializable {
protected int parallelism = 1; private final int id;
protected StreamOperator operator; private final StreamingContext streamingContext;
protected Stream<T> inputStream; private final Stream inputStream;
protected StreamingContext streamingContext; private final StreamOperator operator;
protected Partition<T> partition; private int parallelism = 1;
private Partition<T> partition;
private Stream originalStream;
@SuppressWarnings("unchecked")
public Stream(StreamingContext streamingContext, StreamOperator streamOperator) { public Stream(StreamingContext streamingContext, StreamOperator streamOperator) {
this(streamingContext, null, streamOperator,
selectPartition(streamOperator));
}
public Stream(StreamingContext streamingContext,
StreamOperator streamOperator,
Partition<T> partition) {
this(streamingContext, null, streamOperator, partition);
}
public Stream(Stream inputStream, StreamOperator streamOperator) {
this(inputStream.getStreamingContext(), inputStream, streamOperator,
selectPartition(streamOperator));
}
public Stream(Stream inputStream, StreamOperator streamOperator, Partition<T> partition) {
this(inputStream.getStreamingContext(), inputStream, streamOperator, partition);
}
protected Stream(StreamingContext streamingContext,
Stream inputStream,
StreamOperator streamOperator,
Partition<T> partition) {
this.streamingContext = streamingContext; this.streamingContext = streamingContext;
this.inputStream = inputStream;
this.operator = streamOperator; this.operator = streamOperator;
this.partition = partition;
this.id = streamingContext.generateId(); this.id = streamingContext.generateId();
if (streamOperator instanceof PythonOperator) { if (inputStream != null) {
this.partition = PythonPartition.RoundRobinPartition; this.parallelism = inputStream.getParallelism();
} else {
this.partition = new RoundRobinPartition<>();
} }
} }
public Stream(Stream<T> inputStream, StreamOperator streamOperator) { /**
this.inputStream = inputStream; * Create a proxy stream of original stream.
this.parallelism = inputStream.getParallelism(); * Changes in new stream will be reflected in original stream and vice versa
this.streamingContext = this.inputStream.getStreamingContext(); */
this.operator = streamOperator; protected Stream(Stream originalStream) {
this.id = streamingContext.generateId(); this.originalStream = originalStream;
this.partition = selectPartition(); this.id = originalStream.getId();
this.streamingContext = originalStream.getStreamingContext();
this.inputStream = originalStream.getInputStream();
this.operator = originalStream.getOperator();
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private Partition<T> selectPartition() { private static <T> Partition<T> selectPartition(Operator operator) {
if (inputStream instanceof PythonStream) { switch (operator.getLanguage()) {
return PythonPartition.RoundRobinPartition; case PYTHON:
} else { return (Partition<T>) PythonPartition.RoundRobinPartition;
case JAVA:
return new RoundRobinPartition<>(); return new RoundRobinPartition<>();
default:
throw new UnsupportedOperationException(
"Unsupported language " + operator.getLanguage());
} }
} }
public Stream<T> getInputStream() { public int getId() {
return id;
}
public StreamingContext getStreamingContext() {
return streamingContext;
}
public Stream getInputStream() {
return inputStream; return inputStream;
} }
@ -60,32 +101,47 @@ public abstract class Stream<T> implements Serializable {
return operator; return operator;
} }
public void setOperator(StreamOperator operator) { @SuppressWarnings("unchecked")
this.operator = operator; private S self() {
} return (S) this;
public StreamingContext getStreamingContext() {
return streamingContext;
} }
public int getParallelism() { public int getParallelism() {
return parallelism; return originalStream != null ? originalStream.getParallelism() : parallelism;
} }
public Stream<T> setParallelism(int parallelism) { public S setParallelism(int parallelism) {
if (originalStream != null) {
originalStream.setParallelism(parallelism);
} else {
this.parallelism = parallelism; this.parallelism = parallelism;
return this; }
} return self();
public int getId() {
return id;
} }
@SuppressWarnings("unchecked")
public Partition<T> getPartition() { public Partition<T> getPartition() {
return partition; return originalStream != null ? originalStream.getPartition() : partition;
} }
public void setPartition(Partition<T> partition) { @SuppressWarnings("unchecked")
protected S setPartition(Partition<T> partition) {
if (originalStream != null) {
originalStream.setPartition(partition);
} else {
this.partition = partition; this.partition = partition;
} }
return self();
}
/**
 * Tells whether this stream is a proxy of another stream.
 *
 * @return true when an original stream has been set on this instance, false otherwise.
 */
public boolean isProxyStream() {
  if (originalStream == null) {
    return false;
  }
  return true;
}
/**
 * Returns the stream this proxy stands for.
 *
 * @return the original stream backing this proxy.
 * @throws IllegalArgumentException if this stream is not a proxy stream.
 */
public Stream getOriginalStream() {
  final boolean proxy = isProxyStream();
  Preconditions.checkArgument(proxy);
  return originalStream;
}
public abstract Language getLanguage();
} }

View file

@ -7,8 +7,8 @@ import io.ray.streaming.operator.StreamOperator;
* *
* @param <T> Type of the input data of this sink. * @param <T> Type of the input data of this sink.
*/ */
public class StreamSink<T> extends Stream<T> { public abstract class StreamSink<T> extends Stream<StreamSink<T>, T> {
public StreamSink(Stream<T> inputStream, StreamOperator streamOperator) { public StreamSink(Stream inputStream, StreamOperator streamOperator) {
super(inputStream, streamOperator); super(inputStream, streamOperator);
} }
} }

View file

@ -11,15 +11,15 @@ import java.util.List;
*/ */
public class UnionStream<T> extends DataStream<T> { public class UnionStream<T> extends DataStream<T> {
private List<DataStream> unionStreams; private List<DataStream<T>> unionStreams;
public UnionStream(DataStream input, StreamOperator streamOperator, DataStream<T> other) { public UnionStream(DataStream<T> input, StreamOperator streamOperator, DataStream<T> other) {
super(input, streamOperator); super(input, streamOperator);
this.unionStreams = new ArrayList<>(); this.unionStreams = new ArrayList<>();
this.unionStreams.add(other); this.unionStreams.add(other);
} }
public List<DataStream> getUnionStreams() { public List<DataStream<T>> getUnionStreams() {
return unionStreams; return unionStreams;
} }
} }

View file

@ -1,5 +1,6 @@
package io.ray.streaming.jobgraph; package io.ray.streaming.jobgraph;
import io.ray.streaming.api.Language;
import java.io.Serializable; import java.io.Serializable;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -97,4 +98,14 @@ public class JobGraph implements Serializable {
} }
} }
/**
 * Tells whether the vertices of this job graph use more than one language.
 *
 * @return true if at least one vertex reports a language different from the first vertex's
 *     language; false otherwise, including when the graph has no vertices.
 */
public boolean isCrossLanguageGraph() {
  if (jobVertexList.isEmpty()) {
    // Nothing to compare. Without this guard get(0) throws IndexOutOfBoundsException.
    return false;
  }
  Language language = jobVertexList.get(0).getLanguage();
  for (JobVertex jobVertex : jobVertexList) {
    if (jobVertex.getLanguage() != language) {
      return true;
    }
  }
  return false;
}
} }

View file

@ -1,5 +1,6 @@
package io.ray.streaming.jobgraph; package io.ray.streaming.jobgraph;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.stream.DataStream; import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.Stream; import io.ray.streaming.api.stream.Stream;
import io.ray.streaming.api.stream.StreamSink; import io.ray.streaming.api.stream.StreamSink;
@ -10,8 +11,11 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class JobGraphBuilder { public class JobGraphBuilder {
private static final Logger LOG = LoggerFactory.getLogger(JobGraphBuilder.class);
private JobGraph jobGraph; private JobGraph jobGraph;
@ -41,12 +45,19 @@ public class JobGraphBuilder {
} }
private void processStream(Stream stream) { private void processStream(Stream stream) {
while (stream.isProxyStream()) {
// Proxy stream and original stream are the same logical stream, both refer to the
// same data flow transformation. We should skip proxy stream to avoid applying same
// transformation multiple times.
LOG.debug("Skip proxy stream {} of id {}", stream, stream.getId());
stream = stream.getOriginalStream();
}
StreamOperator streamOperator = stream.getOperator();
Preconditions.checkArgument(stream.getLanguage() == streamOperator.getLanguage(),
"Reference stream should be skipped.");
int vertexId = stream.getId(); int vertexId = stream.getId();
int parallelism = stream.getParallelism(); int parallelism = stream.getParallelism();
JobVertex jobVertex;
StreamOperator streamOperator = stream.getOperator();
JobVertex jobVertex = null;
if (stream instanceof StreamSink) { if (stream instanceof StreamSink) {
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator); jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator);
Stream parentStream = stream.getInputStream(); Stream parentStream = stream.getInputStream();

View file

@ -1,6 +1,8 @@
package io.ray.streaming.message; package io.ray.streaming.message;
import java.util.Objects;
public class KeyRecord<K, T> extends Record<T> { public class KeyRecord<K, T> extends Record<T> {
private K key; private K key;
@ -17,4 +19,24 @@ public class KeyRecord<K, T> extends Record<T> {
public void setKey(K key) { public void setKey(K key) {
this.key = key; this.key = key;
} }
/**
 * Two key records are equal when they have the same runtime class, the superclass
 * considers them equal, and their keys are equal.
 */
@Override
public boolean equals(Object o) {
  if (o == this) {
    return true;
  }
  final boolean sameRuntimeClass = o != null && getClass() == o.getClass();
  // Inherited state is compared by the superclass before the key is looked at.
  if (!sameRuntimeClass || !super.equals(o)) {
    return false;
  }
  final KeyRecord<?, ?> that = (KeyRecord<?, ?>) o;
  return Objects.equals(key, that.key);
}
/**
 * Combines the superclass hash with the hash of the key.
 */
@Override
public int hashCode() {
  final int inherited = super.hashCode();
  return Objects.hash(inherited, key);
}
} }

View file

@ -1,64 +0,0 @@
package io.ray.streaming.message;
import com.google.common.collect.Lists;
import java.io.Serializable;
import java.util.List;
/**
 * Serializable container bundling a task id, a batch id, a stream name and a list of
 * {@link Record}s.
 */
public class Message implements Serializable {
  private int taskId;
  private long batchId;
  private String stream;
  private List<Record> recordList;

  /**
   * Create a message carrying several records.
   *
   * @param taskId value stored as the task id.
   * @param batchId value stored as the batch id.
   * @param stream value stored as the stream name.
   * @param recordList records carried by this message; the list is stored as-is, not copied.
   */
  public Message(int taskId, long batchId, String stream, List<Record> recordList) {
    this.taskId = taskId;
    this.batchId = batchId;
    this.stream = stream;
    this.recordList = recordList;
  }

  /**
   * Create a message carrying a single record.
   *
   * @param taskId value stored as the task id.
   * @param batchId value stored as the batch id.
   * @param stream value stored as the stream name.
   * @param record the only record of this message; it is wrapped in a new list.
   */
  public Message(int taskId, long batchId, String stream, Record record) {
    this.taskId = taskId;
    this.batchId = batchId;
    this.stream = stream;
    this.recordList = Lists.newArrayList(record);
  }

  public int getTaskId() {
    return taskId;
  }

  public void setTaskId(int taskId) {
    this.taskId = taskId;
  }

  public long getBatchId() {
    return batchId;
  }

  public void setBatchId(long batchId) {
    this.batchId = batchId;
  }

  public String getStream() {
    return stream;
  }

  public void setStream(String stream) {
    this.stream = stream;
  }

  public List<Record> getRecordList() {
    return recordList;
  }

  public void setRecordList(List<Record> recordList) {
    this.recordList = recordList;
  }

  /**
   * Returns the record at the given position of the record list.
   *
   * @param index zero-based position in the record list.
   * @return the record stored at {@code index}.
   * @throws IndexOutOfBoundsException if {@code index} is outside the record list.
   */
  public Record getRecord(int index) {
    // The index argument used to be ignored and the first record was always returned.
    return recordList.get(index);
  }
}

View file

@ -1,6 +1,7 @@
package io.ray.streaming.message; package io.ray.streaming.message;
import java.io.Serializable; import java.io.Serializable;
import java.util.Objects;
public class Record<T> implements Serializable { public class Record<T> implements Serializable {
@ -27,6 +28,24 @@ public class Record<T> implements Serializable {
this.stream = stream; this.stream = stream;
} }
/**
 * Two records are equal when they have the same runtime class and both their stream
 * and their value are equal.
 */
@Override
public boolean equals(Object o) {
  if (o == this) {
    return true;
  }
  final boolean sameRuntimeClass = o != null && getClass() == o.getClass();
  if (!sameRuntimeClass) {
    return false;
  }
  final Record<?> that = (Record<?>) o;
  if (!Objects.equals(stream, that.stream)) {
    return false;
  }
  return Objects.equals(value, that.value);
}
/**
 * Hash computed from the stream and the value, the same fields used by equals.
 */
@Override
public int hashCode() {
  final Object[] hashedFields = {stream, value};
  return Objects.hash(hashedFields);
}
@Override @Override
public String toString() { public String toString() {
return value.toString(); return value.toString();

View file

@ -1,6 +1,8 @@
package io.ray.streaming.python; package io.ray.streaming.python;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.function.Function; import io.ray.streaming.api.function.Function;
import org.apache.commons.lang3.StringUtils;
/** /**
* Represents a user defined python function. * Represents a user defined python function.
@ -14,9 +16,8 @@ import io.ray.streaming.api.function.Function;
* *
* <p>If the python data stream api is invoked from python, `function` will be not null.</p> * <p>If the python data stream api is invoked from python, `function` will be not null.</p>
* <p>If the python data stream api is invoked from java, `moduleName` and * <p>If the python data stream api is invoked from java, `moduleName` and
* `className`/`functionName` will be not null.</p> * `functionName` will be not null.</p>
* <p> * <p>
* TODO serialize to bytes using protobuf
*/ */
public class PythonFunction implements Function { public class PythonFunction implements Function {
public enum FunctionInterface { public enum FunctionInterface {
@ -38,23 +39,43 @@ public class PythonFunction implements Function {
} }
} }
private byte[] function; // null if this function is constructed from moduleName/functionName.
private String moduleName; private final byte[] function;
private String className; // null if this function is constructed from serialized python function.
private String functionName; private final String moduleName;
// null if this function is constructed from serialized python function.
private final String functionName;
/** /**
* FunctionInterface can be used to validate python function, * FunctionInterface can be used to validate python function,
* and look up operator class from FunctionInterface. * and look up operator class from FunctionInterface.
*/ */
private String functionInterface; private String functionInterface;
private PythonFunction(byte[] function, /**
String moduleName, * Create a {@link PythonFunction} from a serialized streaming python function.
String className, *
String functionName) { * @param function serialized streaming python function from python driver.
*/
public PythonFunction(byte[] function) {
Preconditions.checkNotNull(function);
this.function = function; this.function = function;
this.moduleName = null;
this.functionName = null;
}
/**
* Create a {@link PythonFunction} from a moduleName and streaming function name.
*
* @param moduleName module name of streaming function.
* @param functionName function name of streaming function. {@code functionName} is the name
* of a python function, or class name of subclass of `ray.streaming.function.`
*/
public PythonFunction(String moduleName,
String functionName) {
Preconditions.checkArgument(StringUtils.isNotBlank(moduleName));
Preconditions.checkArgument(StringUtils.isNotBlank(functionName));
this.function = null;
this.moduleName = moduleName; this.moduleName = moduleName;
this.className = className;
this.functionName = functionName; this.functionName = functionName;
} }
@ -70,10 +91,6 @@ public class PythonFunction implements Function {
return moduleName; return moduleName;
} }
public String getClassName() {
return className;
}
public String getFunctionName() { public String getFunctionName() {
return functionName; return functionName;
} }
@ -82,34 +99,4 @@ public class PythonFunction implements Function {
return functionInterface; return functionInterface;
} }
/**
* Create a {@link PythonFunction} using python serialized function
*
* @param function serialized python function sent from python driver
*/
public static PythonFunction fromFunction(byte[] function) {
return new PythonFunction(function, null, null, null);
}
/**
* Create a {@link PythonFunction} using <code>moduleName</code> and
* <code>className</code>.
*
* @param moduleName python module name
* @param className python class name
*/
public static PythonFunction fromClassName(String moduleName, String className) {
return new PythonFunction(null, moduleName, className, null);
}
/**
* Create a {@link PythonFunction} using <code>moduleName</code> and
* <code>functionName</code>.
*
* @param moduleName python module name
* @param functionName python function name
*/
public static PythonFunction fromFunctionName(String moduleName, String functionName) {
return new PythonFunction(null, moduleName, null, functionName);
}
} }

View file

@ -1,6 +1,8 @@
package io.ray.streaming.python; package io.ray.streaming.python;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.partition.Partition; import io.ray.streaming.api.partition.Partition;
import org.apache.commons.lang3.StringUtils;
/** /**
* Represents a python partition function. * Represents a python partition function.
@ -13,28 +15,33 @@ import io.ray.streaming.api.partition.Partition;
* If this object is constructed from moduleName and className/functionName, * If this object is constructed from moduleName and className/functionName,
* python worker will use `importlib` to load python partition function. * python worker will use `importlib` to load python partition function.
* <p> * <p>
* TODO serialize to bytes using protobuf
*/ */
public class PythonPartition implements Partition { public class PythonPartition implements Partition<Object> {
public static final PythonPartition BroadcastPartition = new PythonPartition( public static final PythonPartition BroadcastPartition = new PythonPartition(
"ray.streaming.partition", "BroadcastPartition", null); "ray.streaming.partition", "BroadcastPartition");
public static final PythonPartition KeyPartition = new PythonPartition( public static final PythonPartition KeyPartition = new PythonPartition(
"ray.streaming.partition", "KeyPartition", null); "ray.streaming.partition", "KeyPartition");
public static final PythonPartition RoundRobinPartition = new PythonPartition( public static final PythonPartition RoundRobinPartition = new PythonPartition(
"ray.streaming.partition", "RoundRobinPartition", null); "ray.streaming.partition", "RoundRobinPartition");
private byte[] partition; private byte[] partition;
private String moduleName; private String moduleName;
private String className;
private String functionName; private String functionName;
public PythonPartition(byte[] partition) { public PythonPartition(byte[] partition) {
Preconditions.checkNotNull(partition);
this.partition = partition; this.partition = partition;
} }
public PythonPartition(String moduleName, String className, String functionName) { /**
* Create a python partition from a moduleName and partition function name
* @param moduleName module name of python partition
* @param functionName function/class name of the partition function.
*/
public PythonPartition(String moduleName, String functionName) {
Preconditions.checkArgument(StringUtils.isNotBlank(moduleName));
Preconditions.checkArgument(StringUtils.isNotBlank(functionName));
this.moduleName = moduleName; this.moduleName = moduleName;
this.className = className;
this.functionName = functionName; this.functionName = functionName;
} }
@ -53,10 +60,6 @@ public class PythonPartition implements Partition {
return moduleName; return moduleName;
} }
public String getClassName() {
return className;
}
public String getFunctionName() { public String getFunctionName() {
return functionName; return functionName;
} }

View file

@ -1,6 +1,9 @@
package io.ray.streaming.python.stream; package io.ray.streaming.python.stream;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.StreamingContext; import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.Stream; import io.ray.streaming.api.stream.Stream;
import io.ray.streaming.python.PythonFunction; import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonFunction.FunctionInterface; import io.ray.streaming.python.PythonFunction.FunctionInterface;
@ -10,19 +13,39 @@ import io.ray.streaming.python.PythonPartition;
/** /**
* Represents a stream of data whose transformations will be executed in python. * Represents a stream of data whose transformations will be executed in python.
*/ */
public class PythonDataStream extends Stream implements PythonStream { public class PythonDataStream extends Stream<PythonDataStream, Object> implements PythonStream {
protected PythonDataStream(StreamingContext streamingContext, protected PythonDataStream(StreamingContext streamingContext,
PythonOperator pythonOperator) { PythonOperator pythonOperator) {
super(streamingContext, pythonOperator); super(streamingContext, pythonOperator);
} }
/**
 * Create a python stream from a streaming context, a python operator and an explicit
 * partition. All three arguments are forwarded unchanged to the superclass constructor.
 */
protected PythonDataStream(StreamingContext streamingContext,
PythonOperator pythonOperator,
Partition<Object> partition) {
super(streamingContext, pythonOperator, partition);
}
public PythonDataStream(PythonDataStream input, PythonOperator pythonOperator) { public PythonDataStream(PythonDataStream input, PythonOperator pythonOperator) {
super(input, pythonOperator); super(input, pythonOperator);
} }
protected PythonDataStream(Stream inputStream, PythonOperator pythonOperator) { public PythonDataStream(PythonDataStream input,
super(inputStream, pythonOperator); PythonOperator pythonOperator,
Partition<Object> partition) {
super(input, pythonOperator, partition);
}
/**
 * Create a python stream that references the passed java stream.
 * Changes in the new stream will be reflected in the referenced stream and vice versa.
 */
public PythonDataStream(DataStream referencedStream) {
super(referencedStream);
}
public PythonDataStream map(String moduleName, String funcName) {
return map(new PythonFunction(moduleName, funcName));
} }
/** /**
@ -36,6 +59,10 @@ public class PythonDataStream extends Stream implements PythonStream {
return new PythonDataStream(this, new PythonOperator(func)); return new PythonDataStream(this, new PythonOperator(func));
} }
/**
 * Apply a flat-map function, given by python module name and function name, to this stream.
 * Builds a {@link PythonFunction} from the two names and delegates to
 * {@code flatMap(PythonFunction)}.
 */
public PythonDataStream flatMap(String moduleName, String funcName) {
  PythonFunction func = new PythonFunction(moduleName, funcName);
  return flatMap(func);
}
/** /**
* Apply a flat-map function to this stream. * Apply a flat-map function to this stream.
* *
@ -47,6 +74,10 @@ public class PythonDataStream extends Stream implements PythonStream {
return new PythonDataStream(this, new PythonOperator(func)); return new PythonDataStream(this, new PythonOperator(func));
} }
/**
 * Apply a filter function, given by python module name and function name, to this stream.
 * Builds a {@link PythonFunction} from the two names and delegates to
 * {@code filter(PythonFunction)}.
 */
public PythonDataStream filter(String moduleName, String funcName) {
  PythonFunction func = new PythonFunction(moduleName, funcName);
  return filter(func);
}
/** /**
* Apply a filter function to this stream. * Apply a filter function to this stream.
* *
@ -59,6 +90,10 @@ public class PythonDataStream extends Stream implements PythonStream {
return new PythonDataStream(this, new PythonOperator(func)); return new PythonDataStream(this, new PythonOperator(func));
} }
/**
 * Apply a sink function, given by python module name and function name, to this stream.
 * Builds a {@link PythonFunction} from the two names and delegates to
 * {@code sink(PythonFunction)}.
 */
public PythonStreamSink sink(String moduleName, String funcName) {
  PythonFunction func = new PythonFunction(moduleName, funcName);
  return sink(func);
}
/** /**
* Apply a sink function and get a StreamSink. * Apply a sink function and get a StreamSink.
* *
@ -70,6 +105,10 @@ public class PythonDataStream extends Stream implements PythonStream {
return new PythonStreamSink(this, new PythonOperator(func)); return new PythonStreamSink(this, new PythonOperator(func));
} }
/**
 * Apply a key-by function, given by python module name and function name, to this stream.
 * Builds a {@link PythonFunction} from the two names and delegates to
 * {@code keyBy(PythonFunction)}.
 */
public PythonKeyDataStream keyBy(String moduleName, String funcName) {
  PythonFunction func = new PythonFunction(moduleName, funcName);
  return keyBy(func);
}
/** /**
* Apply a key-by function to this stream. * Apply a key-by function to this stream.
* *
@ -77,6 +116,7 @@ public class PythonDataStream extends Stream implements PythonStream {
* @return A new KeyDataStream. * @return A new KeyDataStream.
*/ */
public PythonKeyDataStream keyBy(PythonFunction func) { public PythonKeyDataStream keyBy(PythonFunction func) {
checkPartitionCall();
func.setFunctionInterface(FunctionInterface.KEY_FUNCTION); func.setFunctionInterface(FunctionInterface.KEY_FUNCTION);
return new PythonKeyDataStream(this, new PythonOperator(func)); return new PythonKeyDataStream(this, new PythonOperator(func));
} }
@ -87,8 +127,8 @@ public class PythonDataStream extends Stream implements PythonStream {
* @return This stream. * @return This stream.
*/ */
public PythonDataStream broadcast() { public PythonDataStream broadcast() {
this.partition = PythonPartition.BroadcastPartition; checkPartitionCall();
return this; return setPartition(PythonPartition.BroadcastPartition);
} }
/** /**
@ -98,19 +138,33 @@ public class PythonDataStream extends Stream implements PythonStream {
* @return This stream. * @return This stream.
*/ */
public PythonDataStream partitionBy(PythonPartition partition) { public PythonDataStream partitionBy(PythonPartition partition) {
this.partition = partition; checkPartitionCall();
return this; return setPartition(partition);
} }
/** /**
* Set parallelism to current transformation. * If parent stream is a python stream, we can't call partition related methods
* * in the java stream.
* @param parallelism The parallelism to set.
* @return This stream.
*/ */
public PythonDataStream setParallelism(int parallelism) { private void checkPartitionCall() {
this.parallelism = parallelism; if (getInputStream() != null && getInputStream().getLanguage() == Language.JAVA) {
return this; throw new RuntimeException("Partition related methods can't be called on a " +
"python stream if parent stream is a java stream.");
}
}
/**
 * View this python stream as a java {@link DataStream}.
 * The returned stream and this stream are the same logical stream and share one stream id,
 * so a change made through either of them is visible through the other.
 */
public DataStream<Object> asJavaStream() {
  DataStream<Object> javaView = new DataStream<>(this);
  return javaView;
}
/**
 * Always reports {@link Language#PYTHON} for this stream type.
 */
@Override
public Language getLanguage() {
return Language.PYTHON;
}
} }

View file

@ -1,5 +1,7 @@
package io.ray.streaming.python.stream; package io.ray.streaming.python.stream;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.KeyDataStream;
import io.ray.streaming.python.PythonFunction; import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonFunction.FunctionInterface; import io.ray.streaming.python.PythonFunction.FunctionInterface;
import io.ray.streaming.python.PythonOperator; import io.ray.streaming.python.PythonOperator;
@ -8,11 +10,23 @@ import io.ray.streaming.python.PythonPartition;
/** /**
* Represents a python DataStream returned by a key-by operation. * Represents a python DataStream returned by a key-by operation.
*/ */
@SuppressWarnings("unchecked")
public class PythonKeyDataStream extends PythonDataStream implements PythonStream { public class PythonKeyDataStream extends PythonDataStream implements PythonStream {
public PythonKeyDataStream(PythonDataStream input, PythonOperator pythonOperator) { public PythonKeyDataStream(PythonDataStream input, PythonOperator pythonOperator) {
super(input, pythonOperator); super(input, pythonOperator, PythonPartition.KeyPartition);
this.partition = PythonPartition.KeyPartition; }
/**
 * Create a python key stream that references the passed java stream.
 * Changes in the new stream will be reflected in the referenced stream and vice versa.
 */
public PythonKeyDataStream(DataStream referencedStream) {
super(referencedStream);
}
public PythonDataStream reduce(String moduleName, String funcName) {
return reduce(new PythonFunction(moduleName, funcName));
} }
/** /**
@ -26,9 +40,13 @@ public class PythonKeyDataStream extends PythonDataStream implements PythonStrea
return new PythonDataStream(this, new PythonOperator(func)); return new PythonDataStream(this, new PythonOperator(func));
} }
public PythonKeyDataStream setParallelism(int parallelism) { /**
this.parallelism = parallelism; * Convert this stream as a java stream.
return this; * The converted stream and this stream are the same logical stream, which has same stream id.
* Changes in converted stream will be reflected in this stream and vice versa.
*/
public KeyDataStream<Object, Object> asJavaStream() {
return new KeyDataStream(this);
} }
} }

View file

@ -1,5 +1,6 @@
package io.ray.streaming.python.stream; package io.ray.streaming.python.stream;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.stream.StreamSink; import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.python.PythonOperator; import io.ray.streaming.python.PythonOperator;
@ -9,12 +10,12 @@ import io.ray.streaming.python.PythonOperator;
public class PythonStreamSink extends StreamSink implements PythonStream { public class PythonStreamSink extends StreamSink implements PythonStream {
public PythonStreamSink(PythonDataStream input, PythonOperator sinkOperator) { public PythonStreamSink(PythonDataStream input, PythonOperator sinkOperator) {
super(input, sinkOperator); super(input, sinkOperator);
this.streamingContext.addSink(this); getStreamingContext().addSink(this);
} }
public PythonStreamSink setParallelism(int parallelism) { @Override
this.parallelism = parallelism; public Language getLanguage() {
return this; return Language.PYTHON;
} }
} }

View file

@ -13,13 +13,8 @@ import io.ray.streaming.python.PythonPartition;
public class PythonStreamSource extends PythonDataStream implements StreamSource { public class PythonStreamSource extends PythonDataStream implements StreamSource {
private PythonStreamSource(StreamingContext streamingContext, PythonFunction sourceFunction) { private PythonStreamSource(StreamingContext streamingContext, PythonFunction sourceFunction) {
super(streamingContext, new PythonOperator(sourceFunction)); super(streamingContext, new PythonOperator(sourceFunction),
super.partition = PythonPartition.RoundRobinPartition; PythonPartition.RoundRobinPartition);
}
public PythonStreamSource setParallelism(int parallelism) {
this.parallelism = parallelism;
return this;
} }
public static PythonStreamSource from(StreamingContext streamingContext, public static PythonStreamSource from(StreamingContext streamingContext,

View file

@ -21,7 +21,6 @@ public class Config {
public static final String CHANNEL_TYPE = "channel_type"; public static final String CHANNEL_TYPE = "channel_type";
public static final String MEMORY_CHANNEL = "memory_channel"; public static final String MEMORY_CHANNEL = "memory_channel";
public static final String NATIVE_CHANNEL = "native_channel"; public static final String NATIVE_CHANNEL = "native_channel";
public static final String DEFAULT_CHANNEL_TYPE = NATIVE_CHANNEL;
public static final String CHANNEL_SIZE = "channel_size"; public static final String CHANNEL_SIZE = "channel_size";
public static final String CHANNEL_SIZE_DEFAULT = String.valueOf((long)Math.pow(10, 8)); public static final String CHANNEL_SIZE_DEFAULT = String.valueOf((long)Math.pow(10, 8));
public static final String IS_RECREATE = "streaming.is_recreate"; public static final String IS_RECREATE = "streaming.is_recreate";

View file

@ -0,0 +1,40 @@
package io.ray.streaming.api.stream;
import static org.testng.Assert.assertEquals;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.operator.impl.MapOperator;
import io.ray.streaming.python.stream.PythonDataStream;
import io.ray.streaming.python.stream.PythonKeyDataStream;
import org.testng.annotations.Test;
/**
 * Checks that java/python views of one stream share the same id and the same parallelism.
 */
@SuppressWarnings("unchecked")
public class StreamTest {

  /**
   * A data stream converted to python and back keeps its id, and a parallelism set on the
   * last view is visible through every view.
   */
  @Test
  public void testReferencedDataStream() {
    DataStream dataStream = new DataStream(StreamingContext.buildContext(),
        new MapOperator(value -> null));
    PythonDataStream pythonDataStream = dataStream.asPythonStream();
    DataStream javaStream = pythonDataStream.asJavaStream();
    assertEquals(dataStream.getId(), pythonDataStream.getId());
    assertEquals(dataStream.getId(), javaStream.getId());
    javaStream.setParallelism(10);
    // Pin the actual value: comparing the views only with each other would also pass
    // if setParallelism silently did nothing.
    assertEquals(dataStream.getParallelism(), 10);
    assertEquals(dataStream.getParallelism(), pythonDataStream.getParallelism());
    assertEquals(dataStream.getParallelism(), javaStream.getParallelism());
  }

  /**
   * Same checks as above for a stream produced by a key-by operation.
   */
  @Test
  public void testReferencedKeyDataStream() {
    DataStream dataStream = new DataStream(StreamingContext.buildContext(),
        new MapOperator(value -> null));
    KeyDataStream keyDataStream = dataStream.keyBy(value -> null);
    PythonKeyDataStream pythonKeyDataStream = keyDataStream.asPythonStream();
    KeyDataStream javaKeyDataStream = pythonKeyDataStream.asJavaStream();
    assertEquals(keyDataStream.getId(), pythonKeyDataStream.getId());
    assertEquals(keyDataStream.getId(), javaKeyDataStream.getId());
    javaKeyDataStream.setParallelism(10);
    // Pin the actual value so a no-op setParallelism cannot pass the test.
    assertEquals(keyDataStream.getParallelism(), 10);
    assertEquals(keyDataStream.getParallelism(), pythonKeyDataStream.getParallelism());
    assertEquals(keyDataStream.getParallelism(), javaKeyDataStream.getParallelism());
  }
}

View file

@ -38,7 +38,7 @@ public class JobGraphBuilderTest {
public JobGraph buildDataSyncJobGraph() { public JobGraph buildDataSyncJobGraph() {
StreamingContext streamingContext = StreamingContext.buildContext(); StreamingContext streamingContext = StreamingContext.buildContext();
DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext, DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
Lists.newArrayList("a", "b", "c")); Lists.newArrayList("a", "b", "c"));
StreamSink streamSink = dataStream.sink(x -> LOG.info(x)); StreamSink streamSink = dataStream.sink(x -> LOG.info(x));
JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink)); JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink));
@ -73,7 +73,7 @@ public class JobGraphBuilderTest {
public JobGraph buildKeyByJobGraph() { public JobGraph buildKeyByJobGraph() {
StreamingContext streamingContext = StreamingContext.buildContext(); StreamingContext streamingContext = StreamingContext.buildContext();
DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext, DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
Lists.newArrayList("1", "2", "3", "4")); Lists.newArrayList("1", "2", "3", "4"));
StreamSink streamSink = dataStream.keyBy(x -> x) StreamSink streamSink = dataStream.keyBy(x -> x)
.sink(x -> LOG.info(x)); .sink(x -> LOG.info(x));

View file

@ -36,6 +36,11 @@
<artifactId>flatbuffers-java</artifactId> <artifactId>flatbuffers-java</artifactId>
<version>1.9.0.1</version> <version>1.9.0.1</version>
</dependency> </dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
<version>3.0.2</version>
</dependency>
<dependency> <dependency>
<groupId>com.google.guava</groupId> <groupId>com.google.guava</groupId>
<artifactId>guava</artifactId> <artifactId>guava</artifactId>
@ -56,6 +61,11 @@
<artifactId>owner</artifactId> <artifactId>owner</artifactId>
<version>1.0.10</version> <version>1.0.10</version>
</dependency> </dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
<dependency> <dependency>
<groupId>org.mockito</groupId> <groupId>org.mockito</groupId>
<artifactId>mockito-all</artifactId> <artifactId>mockito-all</artifactId>
@ -71,11 +81,6 @@
<artifactId>powermock-api-mockito</artifactId> <artifactId>powermock-api-mockito</artifactId>
<version>1.6.6</version> <version>1.6.6</version>
</dependency> </dependency>
<dependency>
<groupId>org.powermock</groupId>
<artifactId>powermock-core</artifactId>
<version>1.6.6</version>
</dependency>
<dependency> <dependency>
<groupId>org.powermock</groupId> <groupId>org.powermock</groupId>
<artifactId>powermock-module-testng</artifactId> <artifactId>powermock-module-testng</artifactId>

View file

@ -1,9 +1,14 @@
package io.ray.streaming.runtime.core.collector; package io.ray.streaming.runtime.core.collector;
import io.ray.runtime.serializer.Serializer; import io.ray.api.BaseActor;
import io.ray.api.RayPyActor;
import io.ray.streaming.api.Language;
import io.ray.streaming.api.collector.Collector; import io.ray.streaming.api.collector.Collector;
import io.ray.streaming.api.partition.Partition; import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.message.Record; import io.ray.streaming.message.Record;
import io.ray.streaming.runtime.serialization.CrossLangSerializer;
import io.ray.streaming.runtime.serialization.JavaSerializer;
import io.ray.streaming.runtime.serialization.Serializer;
import io.ray.streaming.runtime.transfer.ChannelID; import io.ray.streaming.runtime.transfer.ChannelID;
import io.ray.streaming.runtime.transfer.DataWriter; import io.ray.streaming.runtime.transfer.DataWriter;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -14,15 +19,24 @@ import org.slf4j.LoggerFactory;
public class OutputCollector implements Collector<Record> { public class OutputCollector implements Collector<Record> {
private static final Logger LOGGER = LoggerFactory.getLogger(OutputCollector.class); private static final Logger LOGGER = LoggerFactory.getLogger(OutputCollector.class);
private Partition partition; private final DataWriter writer;
private DataWriter writer; private final ChannelID[] outputQueues;
private ChannelID[] outputQueues; private final Collection<BaseActor> targetActors;
private final Language[] targetLanguages;
private final Partition partition;
private final Serializer javaSerializer = new JavaSerializer();
private final Serializer crossLangSerializer = new CrossLangSerializer();
public OutputCollector(Collection<String> outputQueueIds, public OutputCollector(DataWriter writer,
DataWriter writer, Collection<String> outputQueueIds,
Collection<BaseActor> targetActors,
Partition partition) { Partition partition) {
this.outputQueues = outputQueueIds.stream().map(ChannelID::from).toArray(ChannelID[]::new);
this.writer = writer; this.writer = writer;
this.outputQueues = outputQueueIds.stream().map(ChannelID::from).toArray(ChannelID[]::new);
this.targetActors = targetActors;
this.targetLanguages = targetActors.stream()
.map(actor -> actor instanceof RayPyActor ? Language.PYTHON : Language.JAVA)
.toArray(Language[]::new);
this.partition = partition; this.partition = partition;
LOGGER.debug("OutputCollector constructed, outputQueueIds:{}, partition:{}.", LOGGER.debug("OutputCollector constructed, outputQueueIds:{}, partition:{}.",
outputQueueIds, this.partition); outputQueueIds, this.partition);
@ -31,9 +45,32 @@ public class OutputCollector implements Collector<Record> {
@Override @Override
public void collect(Record record) { public void collect(Record record) {
int[] partitions = this.partition.partition(record, outputQueues.length); int[] partitions = this.partition.partition(record, outputQueues.length);
ByteBuffer msgBuffer = ByteBuffer.wrap(Serializer.encode(record).getLeft()); ByteBuffer javaBuffer = null;
ByteBuffer crossLangBuffer = null;
for (int partition : partitions) { for (int partition : partitions) {
writer.write(outputQueues[partition], msgBuffer); if (targetLanguages[partition] == Language.JAVA) {
// avoid repeated serialization
if (javaBuffer == null) {
byte[] bytes = javaSerializer.serialize(record);
javaBuffer = ByteBuffer.allocate(1 + bytes.length);
javaBuffer.put(Serializer.JAVA_TYPE_ID);
// TODO(chaokunyang) remove copy
javaBuffer.put(bytes);
javaBuffer.flip();
}
writer.write(outputQueues[partition], javaBuffer.duplicate());
} else {
// avoid repeated serialization
if (crossLangBuffer == null) {
byte[] bytes = crossLangSerializer.serialize(record);
crossLangBuffer = ByteBuffer.allocate(1 + bytes.length);
crossLangBuffer.put(Serializer.CROSS_LANG_TYPE_ID);
// TODO(chaokunyang) remove copy
crossLangBuffer.put(bytes);
crossLangBuffer.flip();
}
writer.write(outputQueues[partition], crossLangBuffer.duplicate());
}
} }
} }

View file

@ -12,6 +12,7 @@ import io.ray.streaming.runtime.core.graph.ExecutionNode;
import io.ray.streaming.runtime.core.graph.ExecutionTask; import io.ray.streaming.runtime.core.graph.ExecutionTask;
import io.ray.streaming.runtime.generated.RemoteCall; import io.ray.streaming.runtime.generated.RemoteCall;
import io.ray.streaming.runtime.generated.Streaming; import io.ray.streaming.runtime.generated.Streaming;
import io.ray.streaming.runtime.serialization.MsgPackSerializer;
import java.util.Arrays; import java.util.Arrays;
public class GraphPbBuilder { public class GraphPbBuilder {
@ -74,11 +75,10 @@ public class GraphPbBuilder {
private byte[] serializeFunction(Function function) { private byte[] serializeFunction(Function function) {
if (function instanceof PythonFunction) { if (function instanceof PythonFunction) {
PythonFunction pyFunc = (PythonFunction) function; PythonFunction pyFunc = (PythonFunction) function;
// function_bytes, module_name, class_name, function_name, function_interface // function_bytes, module_name, function_name, function_interface
return serializer.serialize(Arrays.asList( return serializer.serialize(Arrays.asList(
pyFunc.getFunction(), pyFunc.getModuleName(), pyFunc.getFunction(), pyFunc.getModuleName(),
pyFunc.getClassName(), pyFunc.getFunctionName(), pyFunc.getFunctionName(), pyFunc.getFunctionInterface()
pyFunc.getFunctionInterface()
)); ));
} else { } else {
return new byte[0]; return new byte[0];
@ -88,10 +88,10 @@ public class GraphPbBuilder {
private byte[] serializePartition(Partition partition) { private byte[] serializePartition(Partition partition) {
if (partition instanceof PythonPartition) { if (partition instanceof PythonPartition) {
PythonPartition pythonPartition = (PythonPartition) partition; PythonPartition pythonPartition = (PythonPartition) partition;
// partition_bytes, module_name, class_name, function_name // partition_bytes, module_name, function_name
return serializer.serialize(Arrays.asList( return serializer.serialize(Arrays.asList(
pythonPartition.getPartition(), pythonPartition.getModuleName(), pythonPartition.getPartition(), pythonPartition.getModuleName(),
pythonPartition.getClassName(), pythonPartition.getFunctionName() pythonPartition.getFunctionName()
)); ));
} else { } else {
return new byte[0]; return new byte[0];

View file

@ -1,16 +1,21 @@
package io.ray.streaming.runtime.python; package io.ray.streaming.runtime.python;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Primitives;
import io.ray.streaming.api.context.StreamingContext; import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.python.PythonFunction; import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonPartition; import io.ray.streaming.python.PythonPartition;
import io.ray.streaming.python.stream.PythonStreamSource; import io.ray.streaming.python.stream.PythonStreamSource;
import io.ray.streaming.runtime.serialization.MsgPackSerializer;
import io.ray.streaming.runtime.util.ReflectionUtils; import io.ray.streaming.runtime.util.ReflectionUtils;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.msgpack.core.Preconditions;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -68,7 +73,7 @@ public class PythonGateway {
Preconditions.checkNotNull(streamingContext); Preconditions.checkNotNull(streamingContext);
try { try {
PythonStreamSource pythonStreamSource = PythonStreamSource.from( PythonStreamSource pythonStreamSource = PythonStreamSource.from(
streamingContext, PythonFunction.fromFunction(pySourceFunc)); streamingContext, new PythonFunction(pySourceFunc));
referenceMap.put(getReferenceId(pythonStreamSource), pythonStreamSource); referenceMap.put(getReferenceId(pythonStreamSource), pythonStreamSource);
return serializer.serialize(getReferenceId(pythonStreamSource)); return serializer.serialize(getReferenceId(pythonStreamSource));
} catch (Exception e) { } catch (Exception e) {
@ -84,7 +89,7 @@ public class PythonGateway {
} }
public byte[] createPyFunc(byte[] pyFunc) { public byte[] createPyFunc(byte[] pyFunc) {
PythonFunction function = PythonFunction.fromFunction(pyFunc); PythonFunction function = new PythonFunction(pyFunc);
referenceMap.put(getReferenceId(function), function); referenceMap.put(getReferenceId(function), function);
return serializer.serialize(getReferenceId(function)); return serializer.serialize(getReferenceId(function));
} }
@ -98,15 +103,21 @@ public class PythonGateway {
public byte[] callFunction(byte[] paramsBytes) { public byte[] callFunction(byte[] paramsBytes) {
try { try {
List<Object> params = (List<Object>) serializer.deserialize(paramsBytes); List<Object> params = (List<Object>) serializer.deserialize(paramsBytes);
params = processReferenceParameters(params); params = processParameters(params);
LOG.info("callFunction params {}", params); LOG.info("callFunction params {}", params);
String className = (String) params.get(0); String className = (String) params.get(0);
String funcName = (String) params.get(1); String funcName = (String) params.get(1);
Class<?> clz = Class.forName(className, true, this.getClass().getClassLoader()); Class<?> clz = Class.forName(className, true, this.getClass().getClassLoader());
Method method = ReflectionUtils.findMethod(clz, funcName); Class[] paramsTypes = params.subList(2, params.size()).stream()
.map(Object::getClass).toArray(Class[]::new);
Method method = findMethod(clz, funcName, paramsTypes);
Object result = method.invoke(null, params.subList(2, params.size()).toArray()); Object result = method.invoke(null, params.subList(2, params.size()).toArray());
if (returnReference(result)) {
referenceMap.put(getReferenceId(result), result); referenceMap.put(getReferenceId(result), result);
return serializer.serialize(getReferenceId(result)); return serializer.serialize(getReferenceId(result));
} else {
return serializer.serialize(result);
}
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -115,31 +126,78 @@ public class PythonGateway {
public byte[] callMethod(byte[] paramsBytes) { public byte[] callMethod(byte[] paramsBytes) {
try { try {
List<Object> params = (List<Object>) serializer.deserialize(paramsBytes); List<Object> params = (List<Object>) serializer.deserialize(paramsBytes);
params = processReferenceParameters(params); params = processParameters(params);
LOG.info("callMethod params {}", params); LOG.info("callMethod params {}", params);
Object obj = params.get(0); Object obj = params.get(0);
String methodName = (String) params.get(1); String methodName = (String) params.get(1);
Method method = ReflectionUtils.findMethod(obj.getClass(), methodName); Class<?> clz = obj.getClass();
Class[] paramsTypes = params.subList(2, params.size()).stream()
.map(Object::getClass).toArray(Class[]::new);
Method method = findMethod(clz, methodName, paramsTypes);
Object result = method.invoke(obj, params.subList(2, params.size()).toArray()); Object result = method.invoke(obj, params.subList(2, params.size()).toArray());
if (returnReference(result)) {
referenceMap.put(getReferenceId(result), result); referenceMap.put(getReferenceId(result), result);
return serializer.serialize(getReferenceId(result)); return serializer.serialize(getReferenceId(result));
} else {
return serializer.serialize(result);
}
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
private List<Object> processReferenceParameters(List<Object> params) { private static Method findMethod(Class<?> cls, String methodName, Class[] paramsTypes) {
return params.stream().map(this::processReferenceParameter) List<Method> methods = ReflectionUtils.findMethods(cls, methodName);
if (methods.size() == 1) {
return methods.get(0);
}
// Convert all params types to primitive types if it's boxed type
Class[] unwrappedTypes = Arrays.stream(paramsTypes)
.map((Function<Class, Class>) Primitives::unwrap)
.toArray(Class[]::new);
Optional<Method> any = methods.stream()
.filter(m -> Arrays.equals(m.getParameterTypes(), paramsTypes) ||
Arrays.equals(m.getParameterTypes(), unwrappedTypes))
.findAny();
Preconditions.checkArgument(any.isPresent(),
String.format("Method %s with type %s doesn't exist on class %s",
methodName, Arrays.toString(paramsTypes), cls));
return any.get();
}
/**
 * Tells whether a call result must be handed back to the caller as a reference id
 * rather than being serialized by value.
 *
 * <p>Numbers, strings and byte arrays are passed by value; everything else is treated
 * as a reference.
 *
 * NOTE(review): {@code null} and {@code Boolean} results are also classified as
 * references by this check -- confirm that is intended.
 *
 * @param value the result of a reflective call
 * @return {@code true} if the value should be returned as a reference id
 */
private static boolean returnReference(Object value) {
  boolean passedByValue =
      value instanceof Number || value instanceof String || value instanceof byte[];
  return !passedByValue;
}
/**
 * Creates an instance of the named class through its no-argument constructor, stores it
 * in {@code referenceMap} and returns the serialized reference id of the new instance.
 *
 * @param classNameBytes the class name, encoded with this gateway's serializer
 * @return the serialized reference id under which the instance was registered
 * @throws IllegalArgumentException if the class cannot be loaded, has no accessible
 *     no-argument constructor, or its constructor throws
 */
public byte[] newInstance(byte[] classNameBytes) {
  String className = (String) serializer.deserialize(classNameBytes);
  try {
    Class<?> clz = Class.forName(className, true, this.getClass().getClassLoader());
    // Class#newInstance is deprecated because it rethrows any exception raised by the
    // constructor (checked ones included) without declaring it, which would bypass the
    // catch below. Going through the Constructor wraps such failures in
    // InvocationTargetException so they are reported uniformly.
    Object instance = clz.getDeclaredConstructor().newInstance();
    referenceMap.put(getReferenceId(instance), instance);
    return serializer.serialize(getReferenceId(instance));
  } catch (ReflectiveOperationException e) {
    // Covers ClassNotFound, NoSuchMethod, Instantiation, IllegalAccess and
    // InvocationTarget exceptions.
    throw new IllegalArgumentException(
        String.format("Create instance for class %s failed", className), e);
  }
}
private List<Object> processParameters(List<Object> params) {
return params.stream().map(this::processParameter)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
private Object processReferenceParameter(Object o) { private Object processParameter(Object o) {
if (o instanceof String) { if (o instanceof String) {
Object value = referenceMap.get(o); Object value = referenceMap.get(o);
if (value != null) { if (value != null) {
return value; return value;
} }
} }
// Since python can't represent byte/short, we convert all Byte/Short to Integer
if (o instanceof Byte || o instanceof Short) {
return ((Number) o).intValue();
}
return o; return o;
} }

View file

@ -41,15 +41,11 @@ public class JobSchedulerImpl implements JobScheduler {
public void schedule(JobGraph jobGraph, Map<String, String> jobConfig) { public void schedule(JobGraph jobGraph, Map<String, String> jobConfig) {
this.jobConfig = jobConfig; this.jobConfig = jobConfig;
this.jobGraph = jobGraph; this.jobGraph = jobGraph;
if (Ray.internal() == null) {
System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
Ray.init();
}
ExecutionGraph executionGraph = this.taskAssigner.assign(this.jobGraph); ExecutionGraph executionGraph = this.taskAssigner.assign(this.jobGraph);
List<ExecutionNode> executionNodes = executionGraph.getExecutionNodeList(); List<ExecutionNode> executionNodes = executionGraph.getExecutionNodeList();
boolean hasPythonNode = executionNodes.stream() boolean hasPythonNode = executionNodes.stream()
.allMatch(node -> node.getLanguage() == Language.PYTHON); .anyMatch(node -> node.getLanguage() == Language.PYTHON);
RemoteCall.ExecutionGraph executionGraphPb = null; RemoteCall.ExecutionGraph executionGraphPb = null;
if (hasPythonNode) { if (hasPythonNode) {
executionGraphPb = new GraphPbBuilder().buildExecutionGraphPb(executionGraph); executionGraphPb = new GraphPbBuilder().buildExecutionGraphPb(executionGraph);

View file

@ -2,6 +2,8 @@ package io.ray.streaming.runtime.schedule;
import io.ray.api.BaseActor; import io.ray.api.BaseActor;
import io.ray.api.Ray; import io.ray.api.Ray;
import io.ray.api.RayActor;
import io.ray.api.RayPyActor;
import io.ray.api.function.PyActorClass; import io.ray.api.function.PyActorClass;
import io.ray.streaming.jobgraph.JobEdge; import io.ray.streaming.jobgraph.JobEdge;
import io.ray.streaming.jobgraph.JobGraph; import io.ray.streaming.jobgraph.JobGraph;
@ -15,8 +17,11 @@ import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class TaskAssignerImpl implements TaskAssigner { public class TaskAssignerImpl implements TaskAssigner {
private static final Logger LOG = LoggerFactory.getLogger(TaskAssignerImpl.class);
/** /**
* Assign an optimized logical plan to execution graph. * Assign an optimized logical plan to execution graph.
@ -61,11 +66,17 @@ public class TaskAssignerImpl implements TaskAssigner {
private BaseActor createWorker(JobVertex jobVertex) { private BaseActor createWorker(JobVertex jobVertex) {
switch (jobVertex.getLanguage()) { switch (jobVertex.getLanguage()) {
case PYTHON: case PYTHON: {
return Ray.createActor( RayPyActor worker = Ray.createActor(
new PyActorClass("ray.streaming.runtime.worker", "JobWorker")); new PyActorClass("ray.streaming.runtime.worker", "JobWorker"));
case JAVA: LOG.info("Created python worker {}", worker);
return Ray.createActor(JobWorker::new); return worker;
}
case JAVA: {
RayActor<JobWorker> worker = Ray.createActor(JobWorker::new);
LOG.info("Created java worker {}", worker);
return worker;
}
default: default:
throw new UnsupportedOperationException( throw new UnsupportedOperationException(
"Unsupported language " + jobVertex.getLanguage()); "Unsupported language " + jobVertex.getLanguage());

View file

@ -0,0 +1,62 @@
package io.ray.streaming.runtime.serialization;
import io.ray.streaming.message.KeyRecord;
import io.ray.streaming.message.Record;
import java.util.Arrays;
import java.util.List;
/**
* A serializer for cross-lang serialization between java/python.
* TODO implements a more sophisticated serialization framework
*/
public class CrossLangSerializer implements Serializer {
private static final byte RECORD_TYPE_ID = 0;
private static final byte KEY_RECORD_TYPE_ID = 1;
private MsgPackSerializer msgPackSerializer = new MsgPackSerializer();
public byte[] serialize(Object object) {
Record record = (Record) object;
Object value = record.getValue();
Class<? extends Record> clz = record.getClass();
if (clz == Record.class) {
return msgPackSerializer.serialize(Arrays.asList(
RECORD_TYPE_ID, record.getStream(), value));
} else if (clz == KeyRecord.class) {
KeyRecord keyRecord = (KeyRecord) record;
Object key = keyRecord.getKey();
return msgPackSerializer.serialize(Arrays.asList(
KEY_RECORD_TYPE_ID, keyRecord.getStream(), key, value));
} else {
throw new UnsupportedOperationException(
String.format("Serialize %s is unsupported.", record));
}
}
@SuppressWarnings("unchecked")
public Record deserialize(byte[] bytes) {
List list = (List) msgPackSerializer.deserialize(bytes);
Byte typeId = (Byte) list.get(0);
switch (typeId) {
case RECORD_TYPE_ID: {
String stream = (String) list.get(1);
Object value = list.get(2);
Record record = new Record(value);
record.setStream(stream);
return record;
}
case KEY_RECORD_TYPE_ID: {
String stream = (String) list.get(1);
Object key = list.get(2);
Object value = list.get(3);
KeyRecord keyRecord = new KeyRecord(key, value);
keyRecord.setStream(stream);
return keyRecord;
}
default:
throw new UnsupportedOperationException("Unsupported type " + typeId);
}
}
}

View file

@ -0,0 +1,15 @@
package io.ray.streaming.runtime.serialization;
import io.ray.runtime.serializer.FstSerializer;
/**
 * {@link Serializer} implementation that delegates encoding and decoding to
 * {@link FstSerializer}.
 */
public class JavaSerializer implements Serializer {

  /**
   * Encodes the given object with {@code FstSerializer.encode}.
   *
   * @param object the object to encode
   * @return the encoded bytes
   */
  @Override
  public byte[] serialize(Object object) {
    final byte[] encoded = FstSerializer.encode(object);
    return encoded;
  }

  /**
   * Decodes bytes with {@code FstSerializer.decode}.
   *
   * @param bytes the encoded bytes
   * @param <T> the type the caller expects the decoded object to have
   * @return the decoded object
   */
  @Override
  public <T> T deserialize(byte[] bytes) {
    final T decoded = FstSerializer.decode(bytes);
    return decoded;
  }
}

View file

@ -1,4 +1,4 @@
package io.ray.streaming.runtime.python; package io.ray.streaming.runtime.serialization;
import com.google.common.io.BaseEncoding; import com.google.common.io.BaseEncoding;
import java.util.ArrayList; import java.util.ArrayList;
@ -31,6 +31,10 @@ public class MsgPackSerializer {
Class<?> clz = obj.getClass(); Class<?> clz = obj.getClass();
if (clz == Boolean.class) { if (clz == Boolean.class) {
packer.packBoolean((Boolean) obj); packer.packBoolean((Boolean) obj);
} else if (clz == Byte.class) {
packer.packByte((Byte) obj);
} else if (clz == Short.class) {
packer.packShort((Short) obj);
} else if (clz == Integer.class) { } else if (clz == Integer.class) {
packer.packInt((Integer) obj); packer.packInt((Integer) obj);
} else if (clz == Long.class) { } else if (clz == Long.class) {
@ -84,7 +88,11 @@ public class MsgPackSerializer {
return value.asBooleanValue().getBoolean(); return value.asBooleanValue().getBoolean();
case INTEGER: case INTEGER:
IntegerValue iv = value.asIntegerValue(); IntegerValue iv = value.asIntegerValue();
if (iv.isInIntRange()) { if (iv.isInByteRange()) {
return iv.toByte();
} else if (iv.isInShortRange()) {
return iv.toShort();
} else if (iv.isInIntRange()) {
return iv.toInt(); return iv.toInt();
} else if (iv.isInLongRange()) { } else if (iv.isInLongRange()) {
return iv.toLong(); return iv.toLong();

View file

@ -0,0 +1,12 @@
package io.ray.streaming.runtime.serialization;
/**
 * Contract shared by the streaming runtime's serializers: turn an object into bytes and
 * bytes back into an object.
 */
public interface Serializer {
// One-byte tags identifying which serializer produced a payload. OutputCollector writes
// the tag as the first byte of each message and InputStreamTask reads it back to choose
// the matching deserializer.
byte CROSS_LANG_TYPE_ID = 0;
byte JAVA_TYPE_ID = 1;
// NOTE(review): presumably marks payloads produced by the python serializer; no Java
// code visible here reads or writes it -- confirm against the python side.
byte PYTHON_TYPE_ID = 2;
/**
 * Serializes {@code object} into a byte array.
 *
 * @param object the object to serialize
 * @return the serialized bytes
 */
byte[] serialize(Object object);
/**
 * Deserializes {@code bytes} into an object.
 *
 * @param bytes the serialized bytes
 * @param <T> the type the caller expects the result to have
 * @return the deserialized object
 */
<T> T deserialize(byte[] bytes);
}

View file

@ -20,7 +20,7 @@ import java.util.Map;
*/ */
public class ChannelCreationParametersBuilder { public class ChannelCreationParametersBuilder {
public class Parameter { public static class Parameter {
private ActorId actorId; private ActorId actorId;
private FunctionDescriptor asyncFunctionDescriptor; private FunctionDescriptor asyncFunctionDescriptor;
@ -138,7 +138,7 @@ public class ChannelCreationParametersBuilder {
parameter.setAsyncFunctionDescriptor(pyAsyncFunctionDesc); parameter.setAsyncFunctionDescriptor(pyAsyncFunctionDesc);
parameter.setSyncFunctionDescriptor(pySyncFunctionDesc); parameter.setSyncFunctionDescriptor(pySyncFunctionDesc);
} else { } else {
Preconditions.checkArgument(false, "Invalid actor type"); throw new IllegalArgumentException("Invalid actor type");
} }
parameters.add(parameter); parameters.add(parameter);
} }
@ -152,10 +152,10 @@ public class ChannelCreationParametersBuilder {
} }
public String toString() { public String toString() {
String str = ""; StringBuilder str = new StringBuilder();
for (Parameter param : parameters) { for (Parameter param : parameters) {
str += param.toString(); str.append(param.toString());
} }
return str; return str.toString();
} }
} }

View file

@ -40,7 +40,7 @@ public class DataReader {
} }
long timerInterval = Long.parseLong( long timerInterval = Long.parseLong(
conf.getOrDefault(Config.TIMER_INTERVAL_MS, "-1")); conf.getOrDefault(Config.TIMER_INTERVAL_MS, "-1"));
String channelType = conf.getOrDefault(Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE); String channelType = conf.get(Config.CHANNEL_TYPE);
boolean isMock = false; boolean isMock = false;
if (Config.MEMORY_CHANNEL.equals(channelType)) { if (Config.MEMORY_CHANNEL.equals(channelType)) {
isMock = true; isMock = true;

View file

@ -37,7 +37,7 @@ public class DataWriter {
Map<String, String> conf) { Map<String, String> conf) {
Preconditions.checkArgument(!outputChannels.isEmpty()); Preconditions.checkArgument(!outputChannels.isEmpty());
Preconditions.checkArgument(outputChannels.size() == toActors.size()); Preconditions.checkArgument(outputChannels.size() == toActors.size());
ChannelCreationParametersBuilder initialParameters = ChannelCreationParametersBuilder initParameters =
new ChannelCreationParametersBuilder().buildOutputQueueParameters(outputChannels, toActors); new ChannelCreationParametersBuilder().buildOutputQueueParameters(outputChannels, toActors);
byte[][] outputChannelsBytes = outputChannels.stream() byte[][] outputChannelsBytes = outputChannels.stream()
.map(ChannelID::idStrToBytes).toArray(byte[][]::new); .map(ChannelID::idStrToBytes).toArray(byte[][]::new);
@ -47,13 +47,14 @@ public class DataWriter {
for (int i = 0; i < outputChannels.size(); i++) { for (int i = 0; i < outputChannels.size(); i++) {
msgIds[i] = 0; msgIds[i] = 0;
} }
String channelType = conf.getOrDefault(Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE); String channelType = conf.get(Config.CHANNEL_TYPE);
boolean isMock = false; boolean isMock = false;
if (Config.MEMORY_CHANNEL.equals(channelType)) { if (Config.MEMORY_CHANNEL.equalsIgnoreCase(channelType)) {
isMock = true; isMock = true;
LOGGER.info("Using memory channel");
} }
this.nativeWriterPtr = createWriterNative( this.nativeWriterPtr = createWriterNative(
initialParameters, initParameters,
outputChannelsBytes, outputChannelsBytes,
msgIds, msgIds,
channelSize, channelSize,

View file

@ -19,6 +19,7 @@ public class ReflectionUtils {
/** /**
* For covariant return type, return the most specific method. * For covariant return type, return the most specific method.
*
* @return all methods named by {@code methodName}, * @return all methods named by {@code methodName},
*/ */
public static List<Method> findMethods(Class<?> cls, String methodName) { public static List<Method> findMethods(Class<?> cls, String methodName) {

View file

@ -1,5 +1,6 @@
package io.ray.streaming.runtime.worker; package io.ray.streaming.runtime.worker;
import io.ray.api.Ray;
import io.ray.streaming.runtime.core.graph.ExecutionGraph; import io.ray.streaming.runtime.core.graph.ExecutionGraph;
import io.ray.streaming.runtime.core.graph.ExecutionNode; import io.ray.streaming.runtime.core.graph.ExecutionNode;
import io.ray.streaming.runtime.core.graph.ExecutionNode.NodeType; import io.ray.streaming.runtime.core.graph.ExecutionNode.NodeType;
@ -14,11 +15,8 @@ import io.ray.streaming.runtime.worker.context.WorkerContext;
import io.ray.streaming.runtime.worker.tasks.OneInputStreamTask; import io.ray.streaming.runtime.worker.tasks.OneInputStreamTask;
import io.ray.streaming.runtime.worker.tasks.SourceStreamTask; import io.ray.streaming.runtime.worker.tasks.SourceStreamTask;
import io.ray.streaming.runtime.worker.tasks.StreamTask; import io.ray.streaming.runtime.worker.tasks.StreamTask;
import io.ray.streaming.util.Config;
import java.io.Serializable; import java.io.Serializable;
import java.util.Map; import java.util.Map;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -27,6 +25,8 @@ import org.slf4j.LoggerFactory;
*/ */
public class JobWorker implements Serializable { public class JobWorker implements Serializable {
private static final Logger LOGGER = LoggerFactory.getLogger(JobWorker.class); private static final Logger LOGGER = LoggerFactory.getLogger(JobWorker.class);
// special flag to indicate this actor not ready
private static final byte[] NOT_READY_FLAG = new byte[4];
static { static {
EnvUtil.loadNativeLibraries(); EnvUtil.loadNativeLibraries();
@ -54,11 +54,10 @@ public class JobWorker implements Serializable {
this.nodeType = executionNode.getNodeType(); this.nodeType = executionNode.getNodeType();
this.streamProcessor = ProcessBuilder this.streamProcessor = ProcessBuilder
.buildProcessor(executionNode.getStreamOperator()); .buildProcessor(executionNode.getStreamOperator());
LOGGER.debug("Initializing StreamWorker, taskId: {}, operator: {}.", taskId, streamProcessor); LOGGER.info("Initializing StreamWorker, pid {}, taskId: {}, operator: {}.",
EnvUtil.getJvmPid(), taskId, streamProcessor);
String channelType = (String) this.config.getOrDefault( if (!Ray.getRuntimeContext().isSingleProcess()) {
Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
if (channelType.equals(Config.NATIVE_CHANNEL)) {
transferHandler = new TransferHandler(); transferHandler = new TransferHandler();
} }
task = createStreamTask(); task = createStreamTask();
@ -124,6 +123,9 @@ public class JobWorker implements Serializable {
* and receive result from this actor * and receive result from this actor
*/ */
public byte[] onReaderMessageSync(byte[] buffer) { public byte[] onReaderMessageSync(byte[] buffer) {
if (transferHandler == null) {
return NOT_READY_FLAG;
}
return transferHandler.onReaderMessageSync(buffer); return transferHandler.onReaderMessageSync(buffer);
} }
@ -139,6 +141,9 @@ public class JobWorker implements Serializable {
* and receive result from this actor * and receive result from this actor
*/ */
public byte[] onWriterMessageSync(byte[] buffer) { public byte[] onWriterMessageSync(byte[] buffer) {
if (transferHandler == null) {
return NOT_READY_FLAG;
}
return transferHandler.onWriterMessageSync(buffer); return transferHandler.onWriterMessageSync(buffer);
} }
} }

View file

@ -1,7 +1,9 @@
package io.ray.streaming.runtime.worker.tasks; package io.ray.streaming.runtime.worker.tasks;
import io.ray.runtime.serializer.Serializer;
import io.ray.streaming.runtime.core.processor.Processor; import io.ray.streaming.runtime.core.processor.Processor;
import io.ray.streaming.runtime.serialization.CrossLangSerializer;
import io.ray.streaming.runtime.serialization.JavaSerializer;
import io.ray.streaming.runtime.serialization.Serializer;
import io.ray.streaming.runtime.transfer.Message; import io.ray.streaming.runtime.transfer.Message;
import io.ray.streaming.runtime.worker.JobWorker; import io.ray.streaming.runtime.worker.JobWorker;
import io.ray.streaming.util.Config; import io.ray.streaming.util.Config;
@ -10,11 +12,15 @@ public abstract class InputStreamTask extends StreamTask {
private volatile boolean running = true; private volatile boolean running = true;
private volatile boolean stopped = false; private volatile boolean stopped = false;
private long readTimeoutMillis; private long readTimeoutMillis;
private final io.ray.streaming.runtime.serialization.Serializer javaSerializer;
private final io.ray.streaming.runtime.serialization.Serializer crossLangSerializer;
public InputStreamTask(int taskId, Processor processor, JobWorker streamWorker) { public InputStreamTask(int taskId, Processor processor, JobWorker streamWorker) {
super(taskId, processor, streamWorker); super(taskId, processor, streamWorker);
readTimeoutMillis = Long.parseLong((String) streamWorker.getConfig() readTimeoutMillis = Long.parseLong((String) streamWorker.getConfig()
.getOrDefault(Config.READ_TIMEOUT_MS, Config.DEFAULT_READ_TIMEOUT_MS)); .getOrDefault(Config.READ_TIMEOUT_MS, Config.DEFAULT_READ_TIMEOUT_MS));
javaSerializer = new JavaSerializer();
crossLangSerializer = new CrossLangSerializer();
} }
@Override @Override
@ -26,9 +32,15 @@ public abstract class InputStreamTask extends StreamTask {
while (running) { while (running) {
Message item = reader.read(readTimeoutMillis); Message item = reader.read(readTimeoutMillis);
if (item != null) { if (item != null) {
byte[] bytes = new byte[item.body().remaining()]; byte[] bytes = new byte[item.body().remaining() - 1];
byte typeId = item.body().get();
item.body().get(bytes); item.body().get(bytes);
Object obj = Serializer.decode(bytes, Object.class); Object obj;
if (typeId == Serializer.JAVA_TYPE_ID) {
obj = javaSerializer.deserialize(bytes);
} else {
obj = crossLangSerializer.deserialize(bytes);
}
processor.process(obj); processor.process(obj);
} }
} }

View file

@ -26,7 +26,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
public abstract class StreamTask implements Runnable { public abstract class StreamTask implements Runnable {
private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class); private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class);
protected int taskId; protected int taskId;
@ -53,8 +52,8 @@ public abstract class StreamTask implements Runnable {
String queueSize = worker.getConfig() String queueSize = worker.getConfig()
.getOrDefault(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT); .getOrDefault(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT);
queueConf.put(Config.CHANNEL_SIZE, queueSize); queueConf.put(Config.CHANNEL_SIZE, queueSize);
String channelType = worker.getConfig() String channelType = Ray.getRuntimeContext().isSingleProcess() ?
.getOrDefault(Config.CHANNEL_TYPE, Config.MEMORY_CHANNEL); Config.MEMORY_CHANNEL : Config.NATIVE_CHANNEL;
queueConf.put(Config.CHANNEL_TYPE, channelType); queueConf.put(Config.CHANNEL_TYPE, channelType);
ExecutionGraph executionGraph = worker.getExecutionGraph(); ExecutionGraph executionGraph = worker.getExecutionGraph();
@ -82,7 +81,7 @@ public abstract class StreamTask implements Runnable {
LOG.info("Create DataWriter succeed."); LOG.info("Create DataWriter succeed.");
writers.put(edge, writer); writers.put(edge, writer);
Partition partition = edge.getPartition(); Partition partition = edge.getPartition();
collectors.add(new OutputCollector(channelIDs, writer, partition)); collectors.add(new OutputCollector(writer, channelIDs, outputActors.values(), partition));
} }
} }
@ -106,8 +105,8 @@ public abstract class StreamTask implements Runnable {
reader = new DataReader(channelIDs, inputActors, queueConf); reader = new DataReader(channelIDs, inputActors, queueConf);
} }
RuntimeContext runtimeContext = new RayRuntimeContext(worker.getExecutionTask(), RuntimeContext runtimeContext = new RayRuntimeContext(
worker.getConfig(), executionNode.getParallelism()); worker.getExecutionTask(), worker.getConfig(), executionNode.getParallelism());
processor.open(collectors, runtimeContext); processor.open(collectors, runtimeContext);

View file

@ -24,11 +24,13 @@ public abstract class BaseUnitTest {
@BeforeMethod @BeforeMethod
public void testBegin(Method method) { public void testBegin(Method method) {
LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: " + method.getName() + " began >>>>>>>>>>>>>>>>>>>>"); LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: {}.{} began >>>>>>>>>>>>>>>>>>>>",
method.getDeclaringClass(), method.getName());
} }
@AfterMethod @AfterMethod
public void testEnd(Method method) { public void testEnd(Method method) {
LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: " + method.getName() + " end >>>>>>>>>>>>>>>>>>"); LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: {}.{} end >>>>>>>>>>>>>>>>>>>>",
method.getDeclaringClass(), method.getName());
} }
} }

View file

@ -80,7 +80,7 @@ public class ExecutionGraphTest extends BaseUnitTest {
public static JobGraph buildJobGraph() { public static JobGraph buildJobGraph() {
StreamingContext streamingContext = StreamingContext.buildContext(); StreamingContext streamingContext = StreamingContext.buildContext();
DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext, DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
Lists.newArrayList("a", "b", "c")); Lists.newArrayList("a", "b", "c"));
StreamSink streamSink = dataStream.sink(x -> LOG.info(x)); StreamSink streamSink = dataStream.sink(x -> LOG.info(x));

View file

@ -0,0 +1,56 @@
package io.ray.streaming.runtime.demo;
import io.ray.api.Ray;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.FilterFunction;
import io.ray.streaming.api.function.impl.MapFunction;
import io.ray.streaming.api.stream.DataStreamSource;
import io.ray.streaming.runtime.BaseUnitTest;
import java.io.Serializable;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.annotations.Test;
/**
 * Runs a streaming job whose operators alternate between java and python:
 * java source/map, python map/filter, then a java sink.
 */
public class HybridStreamTest extends BaseUnitTest implements Serializable {

  private static final Logger LOG = LoggerFactory.getLogger(HybridStreamTest.class);

  /**
   * Logs each value and converts it to its string form.
   * NOTE(review): not referenced by the java code in this class; presumably
   * it is looked up by class name from the python side - confirm.
   */
  public static class Mapper1 implements MapFunction<Object, Object> {

    @Override
    public Object map(Object value) {
      LOG.info("HybridStreamTest Mapper1 {}", value);
      final String mapped = value.toString();
      return mapped;
    }
  }

  /**
   * Logs each value and keeps only values whose string form has no "b".
   * NOTE(review): not referenced by the java code in this class; presumably
   * it is looked up by class name from the python side - confirm.
   */
  public static class Filter1 implements FilterFunction<Object> {

    @Override
    public boolean filter(Object value) throws Exception {
      LOG.info("HybridStreamTest Filter1 {}", value);
      final boolean containsB = value.toString().contains("b");
      return !containsB;
    }
  }

  /**
   * Builds the java -> python -> java pipeline, executes it, waits three
   * seconds and stops the context. The test makes no assertion on the sink
   * output; it only checks that the steps complete without throwing.
   */
  @Test
  public void testHybridDataStream() throws InterruptedException {
    // Shut down any Ray runtime that is still initialized before building the context.
    Ray.shutdown();

    StreamingContext streamingContext = StreamingContext.buildContext();
    DataStreamSource<String> letters =
        DataStreamSource.fromCollection(streamingContext, Arrays.asList("a", "b", "c"));

    letters
        .map(letter -> letter + letter)
        .asPythonStream()
        .map("ray.streaming.tests.test_hybrid_stream", "map_func1")
        .filter("ray.streaming.tests.test_hybrid_stream", "filter_func1")
        .asJavaStream()
        .sink(result -> System.out.println("HybridStreamTest: " + result));

    streamingContext.execute("HybridStreamTestJob");
    // Wait before stopping the job.
    TimeUnit.SECONDS.sleep(3);
    streamingContext.stop();
    LOG.info("HybridStreamTest succeed");
  }
}

View file

@ -1,6 +1,7 @@
package io.ray.streaming.runtime.demo; package io.ray.streaming.runtime.demo;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import io.ray.api.Ray;
import io.ray.streaming.api.context.StreamingContext; import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.FlatMapFunction; import io.ray.streaming.api.function.impl.FlatMapFunction;
import io.ray.streaming.api.function.impl.ReduceFunction; import io.ray.streaming.api.function.impl.ReduceFunction;
@ -29,6 +30,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable {
@Test @Test
public void testWordCount() { public void testWordCount() {
Ray.shutdown();
StreamingContext streamingContext = StreamingContext.buildContext(); StreamingContext streamingContext = StreamingContext.buildContext();
Map<String, String> config = new HashMap<>(); Map<String, String> config = new HashMap<>();
config.put(Config.STREAMING_BATCH_MAX_COUNT, "1"); config.put(Config.STREAMING_BATCH_MAX_COUNT, "1");
@ -36,7 +38,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable {
streamingContext.withConfig(config); streamingContext.withConfig(config);
List<String> text = new ArrayList<>(); List<String> text = new ArrayList<>();
text.add("hello world eagle eagle eagle"); text.add("hello world eagle eagle eagle");
DataStreamSource<String> streamSource = DataStreamSource.buildSource(streamingContext, text); DataStreamSource<String> streamSource = DataStreamSource.fromCollection(streamingContext, text);
streamSource streamSource
.flatMap((FlatMapFunction<String, WordAndCount>) (value, collector) -> { .flatMap((FlatMapFunction<String, WordAndCount>) (value, collector) -> {
String[] records = value.split(" "); String[] records = value.split(" ");
@ -62,6 +64,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable {
} }
} }
Assert.assertEquals(wordCount, ImmutableMap.of("eagle", 3, "hello", 1)); Assert.assertEquals(wordCount, ImmutableMap.of("eagle", 3, "hello", 1));
streamingContext.stop();
} }
private static class WordAndCount implements Serializable { private static class WordAndCount implements Serializable {

View file

@ -3,6 +3,7 @@ package io.ray.streaming.runtime.python;
import io.ray.streaming.api.stream.StreamSink; import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.jobgraph.JobGraph; import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.jobgraph.JobGraphBuilder; import io.ray.streaming.jobgraph.JobGraphBuilder;
import io.ray.streaming.runtime.serialization.MsgPackSerializer;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;

View file

@ -57,7 +57,7 @@ public class TaskAssignerImplTest extends BaseUnitTest {
public JobGraph buildDataSyncPlan() { public JobGraph buildDataSyncPlan() {
StreamingContext streamingContext = StreamingContext.buildContext(); StreamingContext streamingContext = StreamingContext.buildContext();
DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext, DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
Lists.newArrayList("a", "b", "c")); Lists.newArrayList("a", "b", "c"));
DataStreamSink streamSink = dataStream.sink(LOGGER::info); DataStreamSink streamSink = dataStream.sink(LOGGER::info);
JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink)); JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink));

View file

@ -0,0 +1,26 @@
package io.ray.streaming.runtime.serialization;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
import org.apache.commons.lang3.builder.EqualsBuilder;
import io.ray.streaming.message.KeyRecord;
import io.ray.streaming.message.Record;
import org.testng.annotations.Test;
/**
 * Round-trip test for {@code CrossLangSerializer}: a record that is
 * serialized and then deserialized must compare equal to the original.
 */
public class CrossLangSerializerTest {

  @Test
  @SuppressWarnings("unchecked")
  public void testSerialize() {
    CrossLangSerializer serializer = new CrossLangSerializer();

    // Plain record: compared field by field through reflection.
    Record record = new Record("value");
    record.setStream("stream1");
    Object decodedRecord = serializer.deserialize(serializer.serialize(record));
    assertTrue(EqualsBuilder.reflectionEquals(record, decodedRecord));

    // Keyed record: compared with assertEquals.
    KeyRecord keyRecord = new KeyRecord("key", "value");
    keyRecord.setStream("stream2");
    Object decodedKeyRecord = serializer.deserialize(serializer.serialize(keyRecord));
    assertEquals(keyRecord, decodedKeyRecord);
  }
}

View file

@ -1,4 +1,7 @@
package io.ray.streaming.runtime.python; package io.ray.streaming.runtime.serialization;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -6,25 +9,37 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import org.testng.annotations.Test; import org.testng.annotations.Test;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public class MsgPackSerializerTest { public class MsgPackSerializerTest {
// Round-trips the byte value 1 through MsgPackSerializer (serialize, then
// deserialize) and asserts that the decoded result equals (byte) 1.
@Test
public void testSerializeByte() {
MsgPackSerializer serializer = new MsgPackSerializer();
assertEquals(serializer.deserialize(
serializer.serialize((byte)1)), (byte)1);
}
@Test @Test
public void testSerialize() { public void testSerialize() {
MsgPackSerializer serializer = new MsgPackSerializer(); MsgPackSerializer serializer = new MsgPackSerializer();
assertEquals(serializer.deserialize
(serializer.serialize(Short.MAX_VALUE)), Short.MAX_VALUE);
assertEquals(serializer.deserialize(
serializer.serialize(Integer.MAX_VALUE)), Integer.MAX_VALUE);
assertEquals(serializer.deserialize(
serializer.serialize(Long.MAX_VALUE)), Long.MAX_VALUE);
Map map = new HashMap(); Map map = new HashMap();
List list = new ArrayList<>(); List list = new ArrayList<>();
list.add(null); list.add(null);
list.add(true); list.add(true);
list.add(1);
list.add(1.0d); list.add(1.0d);
list.add("str"); list.add("str");
map.put("k1", "value1"); map.put("k1", "value1");
map.put("k2", 2); map.put("k2", new HashMap<>());
map.put("k3", list); map.put("k3", list);
byte[] bytes = serializer.serialize(map); byte[] bytes = serializer.serialize(map);
Object o = serializer.deserialize(bytes); Object o = serializer.deserialize(bytes);

View file

@ -5,6 +5,7 @@ import io.ray.api.Ray;
import io.ray.api.RayActor; import io.ray.api.RayActor;
import io.ray.api.options.ActorCreationOptions; import io.ray.api.options.ActorCreationOptions;
import io.ray.api.options.ActorCreationOptions.Builder; import io.ray.api.options.ActorCreationOptions.Builder;
import io.ray.runtime.config.RayConfig;
import io.ray.streaming.api.context.StreamingContext; import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.FlatMapFunction; import io.ray.streaming.api.function.impl.FlatMapFunction;
import io.ray.streaming.api.function.impl.ReduceFunction; import io.ray.streaming.api.function.impl.ReduceFunction;
@ -67,7 +68,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
System.setProperty("ray.raylet.config.num_workers_per_process_java", "1"); System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
System.setProperty("ray.run-mode", "CLUSTER"); System.setProperty("ray.run-mode", "CLUSTER");
System.setProperty("ray.redirect-output", "true"); System.setProperty("ray.redirect-output", "true");
// ray init RayConfig.reset();
Ray.init(); Ray.init();
} }
@ -142,6 +143,14 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
@Test(timeOut = 60000) @Test(timeOut = 60000)
public void testWordCount() { public void testWordCount() {
Ray.shutdown();
System.setProperty("ray.resources", "CPU:4,RES-A:4");
System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
System.setProperty("ray.run-mode", "CLUSTER");
System.setProperty("ray.redirect-output", "true");
// ray init
Ray.init();
LOGGER.info("testWordCount"); LOGGER.info("testWordCount");
LOGGER.info("StreamingQueueTest.testWordCount run-mode: {}", LOGGER.info("StreamingQueueTest.testWordCount run-mode: {}",
System.getProperty("ray.run-mode")); System.getProperty("ray.run-mode"));
@ -157,7 +166,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
streamingContext.withConfig(config); streamingContext.withConfig(config);
List<String> text = new ArrayList<>(); List<String> text = new ArrayList<>();
text.add("hello world eagle eagle eagle"); text.add("hello world eagle eagle eagle");
DataStreamSource<String> streamSource = DataStreamSource.buildSource(streamingContext, text); DataStreamSource<String> streamSource = DataStreamSource.fromCollection(streamingContext, text);
streamSource streamSource
.flatMap((FlatMapFunction<String, WordAndCount>) (value, collector) -> { .flatMap((FlatMapFunction<String, WordAndCount>) (value, collector) -> {
String[] records = value.split(" "); String[] records = value.split(" ");
@ -176,7 +185,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
serializeResultToFile(resultFile, wordCount); serializeResultToFile(resultFile, wordCount);
}); });
streamingContext.execute("testWordCount"); streamingContext.execute("testSQWordCount");
Map<String, Integer> checkWordCount = Map<String, Integer> checkWordCount =
(Map<String, Integer>) deserializeResultFromFile(resultFile); (Map<String, Integer>) deserializeResultFromFile(resultFile);

View file

@ -23,8 +23,11 @@ bazel test //streaming/java:all --test_tag_filters="checkstyle" --build_tests_on
echo "Running streaming tests." echo "Running streaming tests."
java -cp "$ROOT_DIR"/../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar\ java -cp "$ROOT_DIR"/../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar\
org.testng.TestNG -d /tmp/ray_streaming_java_test_output "$ROOT_DIR"/testng.xml org.testng.TestNG -d /tmp/ray_streaming_java_test_output "$ROOT_DIR"/testng.xml ||
exit_code=$? exit_code=$?
if [ -z ${exit_code+x} ]; then
exit_code=0
fi
echo "Streaming TestNG results" echo "Streaming TestNG results"
if [ -f "/tmp/ray_streaming_java_test_output/testng-results.xml" ] ; then if [ -f "/tmp/ray_streaming_java_test_output/testng-results.xml" ] ; then
cat /tmp/ray_streaming_java_test_output/testng-results.xml cat /tmp/ray_streaming_java_test_output/testng-results.xml

View file

@ -1,10 +1,13 @@
import logging import logging
import pickle
import typing import typing
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from ray import Language
from ray.actor import ActorHandle
from ray.streaming import function
from ray.streaming import message from ray.streaming import message
from ray.streaming import partition from ray.streaming import partition
from ray.streaming.runtime import serialization
from ray.streaming.runtime.transfer import ChannelID, DataWriter from ray.streaming.runtime.transfer import ChannelID, DataWriter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -31,19 +34,46 @@ class CollectionCollector(Collector):
class OutputCollector(Collector): class OutputCollector(Collector):
def __init__(self, channel_ids: typing.List[str], writer: DataWriter, def __init__(self, writer: DataWriter, channel_ids: typing.List[str],
target_actors: typing.List[ActorHandle],
partition_func: partition.Partition): partition_func: partition.Partition):
self._channel_ids = [ChannelID(id_str) for id_str in channel_ids]
self._writer = writer self._writer = writer
self._channel_ids = [ChannelID(id_str) for id_str in channel_ids]
self._target_languages = []
for actor in target_actors:
if actor._ray_actor_language == Language.PYTHON:
self._target_languages.append(function.Language.PYTHON)
elif actor._ray_actor_language == Language.JAVA:
self._target_languages.append(function.Language.JAVA)
else:
raise Exception("Unsupported language {}"
.format(actor._ray_actor_language))
self._partition_func = partition_func self._partition_func = partition_func
self.python_serializer = serialization.PythonSerializer()
self.cross_lang_serializer = serialization.CrossLangSerializer()
logger.info( logger.info(
"Create OutputCollector, channel_ids {}, partition_func {}".format( "Create OutputCollector, channel_ids {}, partition_func {}".format(
channel_ids, partition_func)) channel_ids, partition_func))
def collect(self, record): def collect(self, record):
partitions = self._partition_func.partition(record, partitions = self._partition_func \
len(self._channel_ids)) .partition(record, len(self._channel_ids))
serialized_message = pickle.dumps(record) python_buffer = None
cross_lang_buffer = None
for partition_index in partitions: for partition_index in partitions:
self._writer.write(self._channel_ids[partition_index], if self._target_languages[partition_index] == \
serialized_message) function.Language.PYTHON:
# avoid repeated serialization
if python_buffer is None:
python_buffer = self.python_serializer.serialize(record)
self._writer.write(
self._channel_ids[partition_index],
serialization._PYTHON_TYPE_ID + python_buffer)
else:
# avoid repeated serialization
if cross_lang_buffer is None:
cross_lang_buffer = self.cross_lang_serializer.serialize(
record)
self._writer.write(
self._channel_ids[partition_index],
serialization._CROSS_LANG_TYPE_ID + cross_lang_buffer)

View file

@ -1,4 +1,4 @@
from abc import ABC from abc import ABC, abstractmethod
from ray.streaming import function from ray.streaming import function
from ray.streaming import partition from ray.streaming import partition
@ -19,7 +19,6 @@ class Stream(ABC):
self.streaming_context = input_stream.streaming_context self.streaming_context = input_stream.streaming_context
else: else:
self.streaming_context = streaming_context self.streaming_context = streaming_context
self.parallelism = 1
def get_streaming_context(self): def get_streaming_context(self):
return self.streaming_context return self.streaming_context
@ -29,7 +28,8 @@ class Stream(ABC):
Returns: Returns:
the parallelism of this transformation the parallelism of this transformation
""" """
return self.parallelism return self._gateway_client(). \
call_method(self._j_stream, "getParallelism")
def set_parallelism(self, parallelism: int): def set_parallelism(self, parallelism: int):
"""Sets the parallelism of this transformation """Sets the parallelism of this transformation
@ -40,7 +40,6 @@ class Stream(ABC):
Returns: Returns:
self self
""" """
self.parallelism = parallelism
self._gateway_client(). \ self._gateway_client(). \
call_method(self._j_stream, "setParallelism", parallelism) call_method(self._j_stream, "setParallelism", parallelism)
return self return self
@ -60,6 +59,10 @@ class Stream(ABC):
return self._gateway_client(). \ return self._gateway_client(). \
call_method(self._j_stream, "getId") call_method(self._j_stream, "getId")
@abstractmethod
def get_language(self):
pass
def _gateway_client(self): def _gateway_client(self):
return self.get_streaming_context()._gateway_client return self.get_streaming_context()._gateway_client
@ -75,6 +78,9 @@ class DataStream(Stream):
super().__init__( super().__init__(
input_stream, j_stream, streaming_context=streaming_context) input_stream, j_stream, streaming_context=streaming_context)
def get_language(self):
return function.Language.PYTHON
def map(self, func): def map(self, func):
""" """
Applies a Map transformation on a :class:`DataStream`. Applies a Map transformation on a :class:`DataStream`.
@ -158,6 +164,7 @@ class DataStream(Stream):
Returns: Returns:
A KeyDataStream A KeyDataStream
""" """
self._check_partition_call()
if not isinstance(func, function.KeyFunction): if not isinstance(func, function.KeyFunction):
func = function.SimpleKeyFunction(func) func = function.SimpleKeyFunction(func)
j_func = self._gateway_client().create_py_func( j_func = self._gateway_client().create_py_func(
@ -175,6 +182,7 @@ class DataStream(Stream):
Returns: Returns:
The DataStream with broadcast partitioning set. The DataStream with broadcast partitioning set.
""" """
self._check_partition_call()
self._gateway_client().call_method(self._j_stream, "broadcast") self._gateway_client().call_method(self._j_stream, "broadcast")
return self return self
@ -191,6 +199,7 @@ class DataStream(Stream):
Returns: Returns:
The DataStream with specified partitioning set. The DataStream with specified partitioning set.
""" """
self._check_partition_call()
if not isinstance(partition_func, partition.Partition): if not isinstance(partition_func, partition.Partition):
partition_func = partition.SimplePartition(partition_func) partition_func = partition.SimplePartition(partition_func)
j_partition = self._gateway_client().create_py_func( j_partition = self._gateway_client().create_py_func(
@ -199,6 +208,16 @@ class DataStream(Stream):
call_method(self._j_stream, "partitionBy", j_partition) call_method(self._j_stream, "partitionBy", j_partition)
return self return self
def _check_partition_call(self):
    """
    Guard for partition related methods of a python stream.

    Raises:
        Exception: if this stream has an input stream and that input
            stream is a java stream.
    """
    parent = self.input_stream
    if parent is None:
        # No upstream stream, nothing to check.
        return
    if parent.get_language() == function.Language.JAVA:
        raise Exception("Partition related methods can't be called on a "
                        "python stream if parent stream is a java stream.")
def sink(self, func): def sink(self, func):
""" """
Create a StreamSink with the given sink. Create a StreamSink with the given sink.
@ -217,8 +236,97 @@ class DataStream(Stream):
call_method(self._j_stream, "sink", j_func) call_method(self._j_stream, "sink", j_func)
return StreamSink(self, j_stream, func) return StreamSink(self, j_stream, func)
def as_java_stream(self):
    """
    View this python stream as a java stream.

    Calls `asJavaStream` on the wrapped java stream object and wraps the
    result in a `JavaDataStream` whose input stream is this stream. Both
    objects represent the same logical stream and share the same stream
    id, so changes made through one are reflected in the other.

    Returns:
        A JavaDataStream
    """
    client = self._gateway_client()
    converted = client.call_method(self._j_stream, "asJavaStream")
    return JavaDataStream(self, converted)
class KeyDataStream(Stream):
class JavaDataStream(Stream):
    """
    Represents a stream of data which applies a transformation executed by
    java. It's also a wrapper of java
    `io.ray.streaming.api.stream.DataStream`

    Every transformation takes `java_func_class`, which is handed to the
    gateway client's `new_instance` to create the java function object.
    """

    def __init__(self, input_stream, j_stream, streaming_context=None):
        super().__init__(
            input_stream, j_stream, streaming_context=streaming_context)

    def get_language(self):
        """Returns the language this stream's operators run in (java)."""
        return function.Language.JAVA

    def map(self, java_func_class):
        """See io.ray.streaming.api.stream.DataStream.map"""
        return JavaDataStream(self, self._unary_call("map", java_func_class))

    def flat_map(self, java_func_class):
        """See io.ray.streaming.api.stream.DataStream.flatMap"""
        return JavaDataStream(self, self._unary_call("flatMap",
                                                     java_func_class))

    def filter(self, java_func_class):
        """See io.ray.streaming.api.stream.DataStream.filter"""
        return JavaDataStream(self, self._unary_call("filter",
                                                     java_func_class))

    def key_by(self, java_func_class):
        """See io.ray.streaming.api.stream.DataStream.keyBy"""
        self._check_partition_call()
        return JavaKeyDataStream(self,
                                 self._unary_call("keyBy", java_func_class))

    def broadcast(self, java_func_class):
        """See io.ray.streaming.api.stream.DataStream.broadcast"""
        # NOTE(review): the python DataStream calls "broadcast" on the java
        # side without any argument; confirm that the java DataStream
        # broadcast method really accepts a function object.
        self._check_partition_call()
        return JavaDataStream(self,
                              self._unary_call("broadcast", java_func_class))

    def partition_by(self, java_func_class):
        """See io.ray.streaming.api.stream.DataStream.partitionBy"""
        self._check_partition_call()
        return JavaDataStream(self,
                              self._unary_call("partitionBy", java_func_class))

    def sink(self, java_func_class):
        """See io.ray.streaming.api.stream.DataStream.sink"""
        return JavaStreamSink(self, self._unary_call("sink", java_func_class))

    def as_python_stream(self):
        """
        Convert this stream as a python DataStream.
        The converted stream and this stream are the same logical stream,
        which has same stream id. Changes in converted stream will be reflected
        in this stream and vice versa.
        """
        j_stream = self._gateway_client(). \
            call_method(self._j_stream, "asPythonStream")
        return DataStream(self, j_stream)

    def _check_partition_call(self):
        """
        If parent stream is a python stream, we can't call partition related
        methods in the java stream

        Raises:
            Exception: if the input stream is a python stream.
        """
        if self.input_stream is not None and \
                self.input_stream.get_language() == function.Language.PYTHON:
            # The two literals are concatenated, so the first one must end
            # with a space to keep "a java stream" readable.
            raise Exception("Partition related methods can't be called on a "
                            "java stream if parent stream is a python stream.")

    def _unary_call(self, func_name, java_func_class):
        """Instantiate `java_func_class` through the gateway and call the
        java stream method `func_name` with it as the only argument.

        Returns:
            Whatever the gateway's `call_method` returns for that call.
        """
        j_func = self._gateway_client().new_instance(java_func_class)
        j_stream = self._gateway_client(). \
            call_method(self._j_stream, func_name, j_func)
        return j_stream
class KeyDataStream(DataStream):
"""Represents a DataStream returned by a key-by operation. """Represents a DataStream returned by a key-by operation.
Wrapper of java io.ray.streaming.python.stream.PythonKeyDataStream Wrapper of java io.ray.streaming.python.stream.PythonKeyDataStream
""" """
@ -251,6 +359,43 @@ class KeyDataStream(Stream):
call_method(self._j_stream, "reduce", j_func) call_method(self._j_stream, "reduce", j_func)
return DataStream(self, j_stream) return DataStream(self, j_stream)
def as_java_stream(self):
    """
    View this python key stream as a java key stream.

    Calls `asJavaStream` on the wrapped java stream object and wraps the
    result in a `JavaKeyDataStream` whose input stream is this stream.
    Both objects represent the same logical stream and share the same
    stream id, so changes made through one are reflected in the other.

    Returns:
        A JavaKeyDataStream
    """
    client = self._gateway_client()
    converted = client.call_method(self._j_stream, "asJavaStream")
    return JavaKeyDataStream(self, converted)
class JavaKeyDataStream(JavaDataStream):
    """
    A java DataStream produced by a key-by operation.
    Wraps the java class io.ray.streaming.api.stream.KeyDataStream
    """

    def __init__(self, input_stream, j_stream):
        super().__init__(input_stream, j_stream)

    def reduce(self, java_func_class):
        """See io.ray.streaming.api.stream.KeyDataStream.reduce"""
        reduced = super()._unary_call("reduce", java_func_class)
        return JavaDataStream(self, reduced)

    def as_python_stream(self):
        """
        View this java key stream as a python KeyDataStream.

        Both objects represent the same logical stream and share the same
        stream id, so changes made through one are reflected in the other.
        """
        client = self._gateway_client()
        converted = client.call_method(self._j_stream, "asPythonStream")
        return KeyDataStream(self, converted)
class StreamSource(DataStream): class StreamSource(DataStream):
"""Represents a source of the DataStream. """Represents a source of the DataStream.
@ -261,9 +406,12 @@ class StreamSource(DataStream):
super().__init__(None, j_stream, streaming_context=streaming_context) super().__init__(None, j_stream, streaming_context=streaming_context)
self.source_func = source_func self.source_func = source_func
def get_language(self):
return function.Language.PYTHON
@staticmethod @staticmethod
def build_source(streaming_context, func): def build_source(streaming_context, func):
"""Build a StreamSource source from a collection. """Build a StreamSource source from a source function.
Args: Args:
streaming_context: Stream context streaming_context: Stream context
func: A instance of `SourceFunction` func: A instance of `SourceFunction`
@ -275,6 +423,34 @@ class StreamSource(DataStream):
return StreamSource(j_stream, streaming_context, func) return StreamSource(j_stream, streaming_context, func)
class JavaStreamSource(JavaDataStream):
    """Represents a source of the java DataStream.
    Wrapper of java io.ray.streaming.api.stream.DataStreamSource
    """

    def __init__(self, j_stream, streaming_context):
        # A source has no input stream, hence None.
        super().__init__(None, j_stream, streaming_context=streaming_context)

    def get_language(self):
        """Returns the language this source runs in (java)."""
        return function.Language.JAVA

    @staticmethod
    def build_source(streaming_context, java_source_func_class):
        """Build a java StreamSource source from a java source function.

        Args:
            streaming_context: Stream context
            java_source_func_class: qualified class name of java SourceFunction
        Returns:
            A java StreamSource
        """
        # `_gateway_client` is an attribute of the streaming context (it is
        # what `Stream._gateway_client()` returns), so it must not be called.
        gateway_client = streaming_context._gateway_client
        j_func = gateway_client.new_instance(java_source_func_class)
        # Class name and function name are separate arguments; without the
        # comma the two literals were fused into one meaningless name.
        j_stream = gateway_client.call_function(
            "io.ray.streaming.api.stream.DataStreamSource",
            "fromSource", streaming_context._j_ctx, j_func)
        return JavaStreamSource(j_stream, streaming_context)
class StreamSink(Stream): class StreamSink(Stream):
"""Represents a sink of the DataStream. """Represents a sink of the DataStream.
Wrapper of java io.ray.streaming.python.stream.PythonStreamSink Wrapper of java io.ray.streaming.python.stream.PythonStreamSink
@ -282,3 +458,18 @@ class StreamSink(Stream):
def __init__(self, input_stream, j_stream, func): def __init__(self, input_stream, j_stream, func):
super().__init__(input_stream, j_stream) super().__init__(input_stream, j_stream)
def get_language(self):
return function.Language.PYTHON
class JavaStreamSink(Stream):
    """A sink of the java DataStream.

    Wraps the java class io.ray.streaming.api.stream.StreamSink
    """

    def __init__(self, input_stream, j_stream):
        # No streaming context is passed, so the base class takes it from
        # the input stream.
        super().__init__(input_stream, j_stream)

    def get_language(self):
        """Returns the language this sink runs in (java)."""
        return function.Language.JAVA

View file

@ -1,13 +1,19 @@
import enum
import importlib import importlib
import inspect import inspect
import sys import sys
from abc import ABC, abstractmethod
import typing import typing
from abc import ABC, abstractmethod
from ray import cloudpickle from ray import cloudpickle
from ray.streaming.runtime import gateway_client from ray.streaming.runtime import gateway_client
class Language(enum.Enum):
    """Language that a stream or a streaming function is executed in."""

    # Executed by java.
    JAVA = 0
    # Executed by python.
    PYTHON = 1
class Function(ABC): class Function(ABC):
"""The base interface for all user-defined functions.""" """The base interface for all user-defined functions."""
@ -60,6 +66,7 @@ class MapFunction(Function):
for each input element. for each input element.
""" """
@abstractmethod
def map(self, value): def map(self, value):
pass pass
@ -70,6 +77,7 @@ class FlatMapFunction(Function):
transform them into zero, one, or more elements. transform them into zero, one, or more elements.
""" """
@abstractmethod
def flat_map(self, value, collector): def flat_map(self, value, collector):
"""Takes an element from the input data set and transforms it into zero, """Takes an element from the input data set and transforms it into zero,
one, or more elements. one, or more elements.
@ -87,6 +95,7 @@ class FilterFunction(Function):
The predicate decides whether to keep the element, or to discard it. The predicate decides whether to keep the element, or to discard it.
""" """
@abstractmethod
def filter(self, value): def filter(self, value):
"""The filter function that evaluates the predicate. """The filter function that evaluates the predicate.
@ -106,6 +115,7 @@ class KeyFunction(Function):
deterministic key for that object. deterministic key for that object.
""" """
@abstractmethod
def key_by(self, value): def key_by(self, value):
"""User-defined function that deterministically extracts the key from """User-defined function that deterministically extracts the key from
an object. an object.
@ -126,6 +136,7 @@ class ReduceFunction(Function):
them into one. them into one.
""" """
@abstractmethod
def reduce(self, old_value, new_value): def reduce(self, old_value, new_value):
""" """
The core method of ReduceFunction, combining two values into one value The core method of ReduceFunction, combining two values into one value
@ -145,6 +156,7 @@ class ReduceFunction(Function):
class SinkFunction(Function): class SinkFunction(Function):
"""Interface for implementing user defined sink functionality.""" """Interface for implementing user defined sink functionality."""
@abstractmethod
def sink(self, value): def sink(self, value):
"""Writes the given value to the sink. This function is called for """Writes the given value to the sink. This function is called for
every record.""" every record."""
@ -283,7 +295,8 @@ def load_function(descriptor_func_bytes: bytes):
Returns: Returns:
a streaming function a streaming function
""" """
function_bytes, module_name, class_name, function_name, function_interface\ assert len(descriptor_func_bytes) > 0
function_bytes, module_name, function_name, function_interface\
= gateway_client.deserialize(descriptor_func_bytes) = gateway_client.deserialize(descriptor_func_bytes)
if function_bytes: if function_bytes:
return deserialize(function_bytes) return deserialize(function_bytes)
@ -292,16 +305,18 @@ def load_function(descriptor_func_bytes: bytes):
assert function_interface assert function_interface
function_interface = getattr(sys.modules[__name__], function_interface) function_interface = getattr(sys.modules[__name__], function_interface)
mod = importlib.import_module(module_name) mod = importlib.import_module(module_name)
if class_name:
assert function_name is None
cls = getattr(mod, class_name)
assert issubclass(cls, function_interface)
return cls()
else:
assert function_name assert function_name
func = getattr(mod, function_name) func = getattr(mod, function_name)
# If func is a python function, user function is a simple python
# function, which will be wrapped as a SimpleXXXFunction.
# If func is a python class, user function is a sub class
# of XXXFunction.
if inspect.isfunction(func):
simple_func_class = _get_simple_function_class(function_interface) simple_func_class = _get_simple_function_class(function_interface)
return simple_func_class(func) return simple_func_class(func)
else:
assert issubclass(func, function_interface)
return func()
def _get_simple_function_class(function_interface): def _get_simple_function_class(function_interface):

View file

@ -8,6 +8,14 @@ class Record:
def __repr__(self): def __repr__(self):
return "Record(%s)".format(self.value) return "Record(%s)".format(self.value)
def __eq__(self, other):
if type(self) is type(other):
return (self.stream, self.value) == (other.stream, other.value)
return False
def __hash__(self):
return hash((self.stream, self.value))
class KeyRecord(Record): class KeyRecord(Record):
"""Data record in a keyed data stream""" """Data record in a keyed data stream"""
@ -15,3 +23,12 @@ class KeyRecord(Record):
def __init__(self, key, value): def __init__(self, key, value):
super().__init__(value) super().__init__(value)
self.key = key self.key = key
def __eq__(self, other):
if type(self) is type(other):
return (self.stream, self.key, self.value) ==\
(other.stream, other.key, other.value)
return False
def __hash__(self):
return hash((self.stream, self.key, self.value))

View file

@ -1,4 +1,5 @@
import importlib import importlib
import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from ray import cloudpickle from ray import cloudpickle
@ -96,22 +97,22 @@ def load_partition(descriptor_partition_bytes: bytes):
Returns: Returns:
partition function partition function
""" """
partition_bytes, module_name, class_name, function_name =\ assert len(descriptor_partition_bytes) > 0
partition_bytes, module_name, function_name =\
gateway_client.deserialize(descriptor_partition_bytes) gateway_client.deserialize(descriptor_partition_bytes)
if partition_bytes: if partition_bytes:
return deserialize(partition_bytes) return deserialize(partition_bytes)
else: else:
assert module_name assert module_name
mod = importlib.import_module(module_name) mod = importlib.import_module(module_name)
# If class_name is not None, user partition is a sub class
# of Partition.
# If function_name is not None, user partition is a simple python
# function, which will be wrapped as a SimplePartition.
if class_name:
assert function_name is None
cls = getattr(mod, class_name)
return cls()
else:
assert function_name assert function_name
func = getattr(mod, function_name) func = getattr(mod, function_name)
# If func is a python function, user partition is a simple python
# function, which will be wrapped as a SimplePartition.
# If func is a python class, user partition is a sub class
# of Partition.
if inspect.isfunction(func):
return SimplePartition(func) return SimplePartition(func)
else:
assert issubclass(func, Partition)
return func()

View file

@ -55,6 +55,11 @@ class GatewayClient:
call = self._python_gateway_actor.callMethod.remote(java_params) call = self._python_gateway_actor.callMethod.remote(java_params)
return deserialize(ray.get(call)) return deserialize(ray.get(call))
def new_instance(self, java_class_name):
    """Call ``newInstance`` on the Python gateway actor.

    Args:
        java_class_name: Name of the Java class; it is serialized before
            being sent to the gateway actor.

    Returns:
        The deserialized value returned by the gateway actor.
    """
    remote_new_instance = self._python_gateway_actor.newInstance.remote
    pending = remote_new_instance(serialize(java_class_name))
    # Block until the actor call finishes, then decode its reply.
    reply = ray.get(pending)
    return deserialize(reply)
def serialize(obj) -> bytes: def serialize(obj) -> bytes:
"""Serialize a python object which can be deserialized by `PythonGateway` """Serialize a python object which can be deserialized by `PythonGateway`

View file

@ -53,7 +53,9 @@ class ExecutionEdge:
self.src_node_id = edge_pb.src_node_id self.src_node_id = edge_pb.src_node_id
self.target_node_id = edge_pb.target_node_id self.target_node_id = edge_pb.target_node_id
partition_bytes = edge_pb.partition partition_bytes = edge_pb.partition
if language == Language.PYTHON: # Sink node doesn't have partition function,
# so we only deserialize partition_bytes when it's not None or empty
if language == Language.PYTHON and partition_bytes:
self.partition = partition.load_partition(partition_bytes) self.partition = partition.load_partition(partition_bytes)

View file

@ -0,0 +1,57 @@
from abc import ABC, abstractmethod
import pickle
import msgpack
from ray.streaming import message
# Type ids stored as the first element of the msgpack list produced by
# CrossLangSerializer; they tell Record and KeyRecord payloads apart.
_RECORD_TYPE_ID = 0
_KEY_RECORD_TYPE_ID = 1
# Single-byte tags. InputStreamTask compares the first byte of a message
# body against _PYTHON_TYPE_ID to choose between the pickle-based and the
# cross-language deserializer.
# NOTE(review): the code that writes these prefixes is not in this module;
# confirm where _CROSS_LANG_TYPE_ID and _JAVA_TYPE_ID are emitted.
_CROSS_LANG_TYPE_ID = b"0"
_JAVA_TYPE_ID = b"1"
_PYTHON_TYPE_ID = b"2"
class Serializer(ABC):
    """Abstract interface for turning objects into bytes and back."""

    @abstractmethod
    def serialize(self, obj):
        """Return the serialized form of ``obj``."""

    @abstractmethod
    def deserialize(self, serialized_bytes):
        """Rebuild an object from ``serialized_bytes``."""
class PythonSerializer(Serializer):
    """Serializer backed by the standard ``pickle`` module."""

    def serialize(self, obj):
        """Pickle ``obj`` and return the resulting bytes."""
        pickled = pickle.dumps(obj)
        return pickled

    def deserialize(self, serialized_bytes):
        """Unpickle ``serialized_bytes`` and return the object.

        NOTE(review): ``pickle.loads`` can execute arbitrary code while
        loading, so this must only be given bytes from trusted peers.
        """
        restored = pickle.loads(serialized_bytes)
        return restored
class CrossLangSerializer(Serializer):
    """Msgpack-based serializer for stream elements exchanged between
    Java and Python."""

    def serialize(self, obj):
        """Pack a ``Record`` or ``KeyRecord`` into msgpack bytes.

        The packed list starts with a type id, followed by the stream and
        then either ``value`` (Record) or ``key, value`` (KeyRecord).
        The type check is exact: any other type, including subclasses
        other than ``KeyRecord``, raises ``Exception``.
        """
        obj_type = type(obj)
        if obj_type is message.Record:
            payload = [_RECORD_TYPE_ID, obj.stream, obj.value]
        elif obj_type is message.KeyRecord:
            payload = [_KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value]
        else:
            raise Exception("Unsupported value {}".format(obj))
        return msgpack.packb(payload, use_bin_type=True)

    def deserialize(self, data):
        """Unpack msgpack bytes produced by ``serialize``.

        Returns a ``Record`` or ``KeyRecord`` with its ``stream`` restored,
        chosen by the leading type id; an unknown id raises ``Exception``.
        """
        fields = msgpack.unpackb(data, raw=False)
        if fields[0] == _RECORD_TYPE_ID:
            stream, value = fields[1:]
            restored = message.Record(value)
            restored.stream = stream
            return restored
        if fields[0] == _KEY_RECORD_TYPE_ID:
            stream, key, value = fields[1:]
            restored = message.KeyRecord(key, value)
            restored.stream = stream
            return restored
        raise Exception("Unsupported type id {}, type {}".format(
            fields[0], type(fields[0])))

View file

@ -1,11 +1,13 @@
import logging import logging
import pickle
import threading import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from ray.streaming.collector import OutputCollector from ray.streaming.collector import OutputCollector
from ray.streaming.config import Config from ray.streaming.config import Config
from ray.streaming.context import RuntimeContextImpl from ray.streaming.context import RuntimeContextImpl
from ray.streaming.runtime import serialization
from ray.streaming.runtime.serialization import \
PythonSerializer, CrossLangSerializer
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -38,36 +40,40 @@ class StreamTask(ABC):
# writers # writers
collectors = [] collectors = []
for edge in execution_node.output_edges: for edge in execution_node.output_edges:
output_actor_ids = {} output_actors_map = {}
task_id2_worker = execution_graph.get_task_id2_worker_by_node_id( task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
edge.target_node_id) edge.target_node_id)
for target_task_id, target_actor in task_id2_worker.items(): for target_task_id, target_actor in task_id2_worker.items():
channel_name = ChannelID.gen_id(self.task_id, target_task_id, channel_name = ChannelID.gen_id(self.task_id, target_task_id,
execution_graph.build_time()) execution_graph.build_time())
output_actor_ids[channel_name] = target_actor output_actors_map[channel_name] = target_actor
if len(output_actor_ids) > 0: if len(output_actors_map) > 0:
channel_ids = list(output_actor_ids.keys()) channel_ids = list(output_actors_map.keys())
to_actor_ids = list(output_actor_ids.values()) target_actors = list(output_actors_map.values())
writer = DataWriter(channel_ids, to_actor_ids, channel_conf) logger.info(
logger.info("Create DataWriter succeed.") "Create DataWriter channel_ids {}, target_actors {}."
.format(channel_ids, target_actors))
writer = DataWriter(channel_ids, target_actors, channel_conf)
self.writers[edge] = writer self.writers[edge] = writer
collectors.append( collectors.append(
OutputCollector(channel_ids, writer, edge.partition)) OutputCollector(writer, channel_ids, target_actors,
edge.partition))
# readers # readers
input_actor_ids = {} input_actor_map = {}
for edge in execution_node.input_edges: for edge in execution_node.input_edges:
task_id2_worker = execution_graph.get_task_id2_worker_by_node_id( task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
edge.src_node_id) edge.src_node_id)
for src_task_id, src_actor in task_id2_worker.items(): for src_task_id, src_actor in task_id2_worker.items():
channel_name = ChannelID.gen_id(src_task_id, self.task_id, channel_name = ChannelID.gen_id(src_task_id, self.task_id,
execution_graph.build_time()) execution_graph.build_time())
input_actor_ids[channel_name] = src_actor input_actor_map[channel_name] = src_actor
if len(input_actor_ids) > 0: if len(input_actor_map) > 0:
channel_ids = list(input_actor_ids.keys()) channel_ids = list(input_actor_map.keys())
from_actor_ids = list(input_actor_ids.values()) from_actors = list(input_actor_map.values())
logger.info("Create DataReader, channels {}.".format(channel_ids)) logger.info("Create DataReader, channels {}, input_actors {}."
self.reader = DataReader(channel_ids, from_actor_ids, channel_conf) .format(channel_ids, from_actors))
self.reader = DataReader(channel_ids, from_actors, channel_conf)
def exit_handler(): def exit_handler():
# Make DataReader stop read data when MockQueue destructor # Make DataReader stop read data when MockQueue destructor
@ -111,6 +117,8 @@ class InputStreamTask(StreamTask):
self.read_timeout_millis = \ self.read_timeout_millis = \
int(worker.config.get(Config.READ_TIMEOUT_MS, int(worker.config.get(Config.READ_TIMEOUT_MS,
Config.DEFAULT_READ_TIMEOUT_MS)) Config.DEFAULT_READ_TIMEOUT_MS))
self.python_serializer = PythonSerializer()
self.cross_lang_serializer = CrossLangSerializer()
def init(self): def init(self):
pass pass
@ -120,7 +128,11 @@ class InputStreamTask(StreamTask):
item = self.reader.read(self.read_timeout_millis) item = self.reader.read(self.read_timeout_millis)
if item is not None: if item is not None:
msg_data = item.body() msg_data = item.body()
msg = pickle.loads(msg_data) type_id = msg_data[:1]
if (type_id == serialization._PYTHON_TYPE_ID):
msg = self.python_serializer.deserialize(msg_data[1:])
else:
msg = self.cross_lang_serializer.deserialize(msg_data[1:])
self.processor.process(msg) self.processor.process(msg)
self.stopped = True self.stopped = True

View file

@ -147,13 +147,17 @@ class ChannelCreationParametersBuilder:
wrap initial parameters needed by a streaming queue wrap initial parameters needed by a streaming queue
""" """
_java_reader_async_function_descriptor = JavaFunctionDescriptor( _java_reader_async_function_descriptor = JavaFunctionDescriptor(
"io.ray.streaming.runtime.worker", "onReaderMessage", "([B)V") "io.ray.streaming.runtime.worker.JobWorker", "onReaderMessage",
"([B)V")
_java_reader_sync_function_descriptor = JavaFunctionDescriptor( _java_reader_sync_function_descriptor = JavaFunctionDescriptor(
"io.ray.streaming.runtime.worker", "onReaderMessageSync", "([B)[B") "io.ray.streaming.runtime.worker.JobWorker", "onReaderMessageSync",
"([B)[B")
_java_writer_async_function_descriptor = JavaFunctionDescriptor( _java_writer_async_function_descriptor = JavaFunctionDescriptor(
"io.ray.streaming.runtime.worker", "onWriterMessage", "([B)V") "io.ray.streaming.runtime.worker.JobWorker", "onWriterMessage",
"([B)V")
_java_writer_sync_function_descriptor = JavaFunctionDescriptor( _java_writer_sync_function_descriptor = JavaFunctionDescriptor(
"io.ray.streaming.runtime.worker", "onWriterMessageSync", "([B)[B") "io.ray.streaming.runtime.worker.JobWorker", "onWriterMessageSync",
"([B)[B")
_python_reader_async_function_descriptor = PythonFunctionDescriptor( _python_reader_async_function_descriptor = PythonFunctionDescriptor(
"ray.streaming.runtime.worker", "on_reader_message", "JobWorker") "ray.streaming.runtime.worker", "on_reader_message", "JobWorker")
_python_reader_sync_function_descriptor = PythonFunctionDescriptor( _python_reader_sync_function_descriptor = PythonFunctionDescriptor(

View file

@ -10,6 +10,9 @@ from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Special flag indicating that this actor is not ready yet.
_NOT_READY_FLAG_ = b" " * 4
@ray.remote @ray.remote
class JobWorker(object): class JobWorker(object):
@ -66,23 +69,31 @@ class JobWorker(object):
type(self.stream_processor)) type(self.stream_processor))
def on_reader_message(self, buffer: bytes): def on_reader_message(self, buffer: bytes):
"""used in direct call mode""" """Called by upstream queue writer to send data message to downstream
queue reader.
"""
self.reader_client.on_reader_message(buffer) self.reader_client.on_reader_message(buffer)
def on_reader_message_sync(self, buffer: bytes): def on_reader_message_sync(self, buffer: bytes):
"""used in direct call mode""" """Called by upstream queue writer to send control message to downstream
queue reader.
"""
if self.reader_client is None: if self.reader_client is None:
return b" " * 4 # special flag to indicate this actor not ready return _NOT_READY_FLAG_
result = self.reader_client.on_reader_message_sync(buffer) result = self.reader_client.on_reader_message_sync(buffer)
return result.to_pybytes() return result.to_pybytes()
def on_writer_message(self, buffer: bytes): def on_writer_message(self, buffer: bytes):
"""used in direct call mode""" """Called by downstream queue reader to send notify message to
upstream queue writer.
"""
self.writer_client.on_writer_message(buffer) self.writer_client.on_writer_message(buffer)
def on_writer_message_sync(self, buffer: bytes): def on_writer_message_sync(self, buffer: bytes):
"""used in direct call mode""" """Called by downstream queue reader to send control message to
upstream queue writer.
"""
if self.writer_client is None: if self.writer_client is None:
return b" " * 4 # special flag to indicate this actor not ready return _NOT_READY_FLAG_
result = self.writer_client.on_writer_message_sync(buffer) result = self.writer_client.on_writer_message_sync(buffer)
return result.to_pybytes() return result.to_pybytes()

View file

@ -14,9 +14,9 @@ class MapFunc(function.MapFunction):
def test_load_function(): def test_load_function():
# function_bytes, module_name, class_name, function_name, # function_bytes, module_name, function_name/class_name,
# function_interface # function_interface
descriptor_func_bytes = gateway_client.serialize( descriptor_func_bytes = gateway_client.serialize(
[None, __name__, MapFunc.__name__, None, "MapFunction"]) [None, __name__, MapFunc.__name__, "MapFunction"])
func = function.load_function(descriptor_func_bytes) func = function.load_function(descriptor_func_bytes)
assert type(func) is MapFunc assert type(func) is MapFunc

View file

@ -0,0 +1,70 @@
import json
import ray
from ray.streaming import StreamingContext
import subprocess
import os
def map_func1(x):
    """Log the incoming element and return its string form."""
    print("HybridStreamTest map_func1", x)
    as_text = str(x)
    return as_text
def filter_func1(x):
    """Log the incoming element and keep it only if it has no "b"."""
    print("HybridStreamTest filter_func1", x)
    contains_b = "b" in x
    return not contains_b
def sink_func1(x):
    """Sink that only logs the element it receives."""
    label = "HybridStreamTest sink_func1 value:"
    print(label, x)
def test_hybrid_stream():
    """Run a Python -> Java -> Python streaming job end to end.

    Builds the Java streaming test jar with bazel, starts Ray with that jar
    on the Java worker classpath, and submits a job whose Java operators
    map and filter the values "a", "b", "c" before a Python sink appends
    them to a file. The file must end up containing "a" and "c" but not
    "b".
    """
    subprocess.check_call(
        ["bazel", "build", "//streaming/java:all_streaming_tests_deploy.jar"])
    current_dir = os.path.abspath(os.path.dirname(__file__))
    jar_path = os.path.join(
        current_dir,
        "../../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar")
    jar_path = os.path.abspath(jar_path)
    print("jar_path", jar_path)
    java_worker_options = json.dumps(["-classpath", jar_path])
    print("java_worker_options", java_worker_options)
    sink_file = "/tmp/ray_streaming_test_hybrid_stream.txt"
    assert not ray.is_initialized()
    ray.init(
        load_code_from_local=True,
        include_java=True,
        java_worker_options=java_worker_options,
        _internal_config=json.dumps({
            "num_workers_per_process_java": 1
        }))
    try:
        # Start from a clean sink file so earlier runs cannot satisfy the
        # assertions below.
        if os.path.exists(sink_file):
            os.remove(sink_file)

        def sink_func(x):
            print("HybridStreamTest", x)
            with open(sink_file, "a") as f:
                f.write(str(x))

        ctx = StreamingContext.Builder().build()
        ctx.from_values("a", "b", "c") \
            .as_java_stream() \
            .map("io.ray.streaming.runtime.demo.HybridStreamTest$Mapper1") \
            .filter("io.ray.streaming.runtime.demo.HybridStreamTest$Filter1") \
            .as_python_stream() \
            .sink(sink_func)
        ctx.submit("HybridStreamTest")
        import time
        # Give the submitted job time to push all records to the sink.
        time.sleep(3)
    finally:
        # Always stop Ray, even if job setup or submission fails, so later
        # tests in the same process do not find it still initialized.
        ray.shutdown()
    with open(sink_file, "r") as f:
        result = f.read()
        assert "a" in result
        assert "b" not in result
        assert "c" in result
if __name__ == "__main__":
test_hybrid_stream()

View file

@ -0,0 +1,13 @@
from ray.streaming.runtime.serialization import CrossLangSerializer
from ray.streaming.message import Record, KeyRecord
def test_serialize():
    """A Record and a KeyRecord must survive a CrossLangSerializer
    serialize/deserialize round trip unchanged."""
    serializer = CrossLangSerializer()
    record = Record("value")
    record.stream = "stream1"
    key_record = KeyRecord("key", "value")
    key_record.stream = "stream2"
    for original in (record, key_record):
        restored = serializer.deserialize(serializer.serialize(original))
        assert original == restored

View file

@ -0,0 +1,31 @@
import ray
from ray.streaming import StreamingContext
def test_data_stream():
    """Converting a stream to a Java stream and back must keep its id and
    keep parallelism in sync across all three views."""
    ray.init(load_code_from_local=True, include_java=True)
    try:
        ctx = StreamingContext.Builder().build()
        stream = ctx.from_values(1, 2, 3)
        java_stream = stream.as_java_stream()
        python_stream = java_stream.as_python_stream()
        assert stream.get_id() == java_stream.get_id()
        assert stream.get_id() == python_stream.get_id()
        # Changing parallelism through one view must be visible in the
        # others.
        python_stream.set_parallelism(10)
        assert stream.get_parallelism() == java_stream.get_parallelism()
        assert stream.get_parallelism() == python_stream.get_parallelism()
    finally:
        # Shut Ray down even when an assertion fails; otherwise the next
        # test's ray.init() runs against an already-initialized Ray.
        ray.shutdown()
def test_key_data_stream():
    """Converting a keyed stream to a Java stream and back must keep its id
    and keep parallelism in sync across all three views."""
    ray.init(load_code_from_local=True, include_java=True)
    try:
        ctx = StreamingContext.Builder().build()
        key_stream = ctx.from_values(
            "a", "b", "c").map(lambda x: (x, 1)).key_by(lambda x: x[0])
        java_stream = key_stream.as_java_stream()
        python_stream = java_stream.as_python_stream()
        assert key_stream.get_id() == java_stream.get_id()
        assert key_stream.get_id() == python_stream.get_id()
        # Changing parallelism through one view must be visible in the
        # others.
        python_stream.set_parallelism(10)
        assert key_stream.get_parallelism() == java_stream.get_parallelism()
        assert key_stream.get_parallelism() == \
            python_stream.get_parallelism()
    finally:
        # Shut Ray down even when an assertion fails; otherwise the next
        # test's ray.init() runs against an already-initialized Ray.
        ray.shutdown()

View file

@ -32,7 +32,9 @@ def test_simple_word_count():
def sink_func(x): def sink_func(x):
with open(sink_file, "a") as f: with open(sink_file, "a") as f:
f.write("{}:{},".format(x[0], x[1])) line = "{}:{},".format(x[0], x[1])
print("sink_func", line)
f.write(line)
ctx.from_values("a", "b", "c") \ ctx.from_values("a", "b", "c") \
.set_parallelism(1) \ .set_parallelism(1) \

View file

@ -26,6 +26,13 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(
return reinterpret_cast<jlong>(reader_client); return reinterpret_cast<jlong>(reader_client);
} }
// JNI entry point for TransferHandler.handleWriterMessageNative: converts the
// Java byte array to a buffer and hands it to the native WriterClient's
// OnWriterMessage. Returns nothing to the Java caller.
JNIEXPORT void JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative(
    JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
  // `ptr` is the WriterClient address carried through Java as a jlong.
  reinterpret_cast<WriterClient *>(ptr)->OnWriterMessage(
      JByteArrayToBuffer(env, bytes));
}
JNIEXPORT jbyteArray JNICALL JNIEXPORT jbyteArray JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative( Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative(
JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) { JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {