From 91f630f70959098facef4ab4e1a98d1f48f9a586 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Wed, 29 Apr 2020 13:42:08 +0800 Subject: [PATCH] [Streaming] Streaming Cross-Lang API (#7464) --- python/ray/worker.py | 6 + streaming/java/BUILD.bazel | 11 +- streaming/java/dependencies.bzl | 10 +- streaming/java/streaming-api/pom.xml | 20 ++ streaming/java/streaming-api/pom_template.xml | 5 + .../streaming/api/context/ClusterStarter.java | 129 +++++++++++ .../api/context/StreamingContext.java | 25 +++ .../ray/streaming/api/stream/DataStream.java | 77 +++++-- .../streaming/api/stream/DataStreamSink.java | 11 +- .../api/stream/DataStreamSource.java | 23 +- .../streaming/api/stream/KeyDataStream.java | 26 ++- .../io/ray/streaming/api/stream/Stream.java | 144 +++++++++---- .../ray/streaming/api/stream/StreamSink.java | 4 +- .../ray/streaming/api/stream/UnionStream.java | 6 +- .../io/ray/streaming/jobgraph/JobGraph.java | 11 + .../streaming/jobgraph/JobGraphBuilder.java | 19 +- .../io/ray/streaming/message/KeyRecord.java | 22 ++ .../io/ray/streaming/message/Message.java | 64 ------ .../java/io/ray/streaming/message/Record.java | 19 ++ .../ray/streaming/python/PythonFunction.java | 77 +++---- .../ray/streaming/python/PythonPartition.java | 27 +-- .../python/stream/PythonDataStream.java | 82 +++++-- .../python/stream/PythonKeyDataStream.java | 30 ++- .../python/stream/PythonStreamSink.java | 9 +- .../python/stream/PythonStreamSource.java | 11 +- .../java/io/ray/streaming/util/Config.java | 1 - .../ray/streaming/api/stream/StreamTest.java | 40 ++++ .../jobgraph/JobGraphBuilderTest.java | 4 +- streaming/java/streaming-runtime/pom.xml | 15 +- .../core/collector/OutputCollector.java | 55 ++++- .../runtime/python/GraphPbBuilder.java | 10 +- .../runtime/python/PythonGateway.java | 86 ++++++-- .../runtime/schedule/JobSchedulerImpl.java | 6 +- .../runtime/schedule/TaskAssignerImpl.java | 19 +- .../serialization/CrossLangSerializer.java | 62 ++++++ .../runtime/serialization/JavaSerializer.java | 15 ++ .../MsgPackSerializer.java | 12 +- .../runtime/serialization/Serializer.java | 12 ++ .../ChannelCreationParametersBuilder.java | 10 +- .../runtime/transfer/DataReader.java | 2 +- .../runtime/transfer/DataWriter.java | 9 +- .../runtime/util/ReflectionUtils.java | 1 + .../streaming/runtime/worker/JobWorker.java | 21 +- .../runtime/worker/tasks/InputStreamTask.java | 18 +- .../runtime/worker/tasks/StreamTask.java | 11 +- .../ray/streaming/runtime/BaseUnitTest.java | 6 +- .../core/graph/ExecutionGraphTest.java | 2 +- .../runtime/demo/HybridStreamTest.java | 56 +++++ .../streaming/runtime/demo/WordCountTest.java | 5 +- .../runtime/python/PythonGatewayTest.java | 1 + .../schedule/TaskAssignerImplTest.java | 2 +- .../CrossLangSerializerTest.java | 26 +++ .../MsgPackSerializerTest.java | 25 ++- .../streamingqueue/StreamingQueueTest.java | 15 +- streaming/java/test.sh | 5 +- streaming/python/collector.py | 46 +++- streaming/python/datastream.py | 203 +++++++++++++++++- streaming/python/function.py | 35 ++- streaming/python/message.py | 17 ++ streaming/python/partition.py | 23 +- streaming/python/runtime/gateway_client.py | 5 + streaming/python/runtime/graph.py | 4 +- streaming/python/runtime/serialization.py | 57 +++++ streaming/python/runtime/task.py | 46 ++-- streaming/python/runtime/transfer.py | 12 +- streaming/python/runtime/worker.py | 23 +- streaming/python/tests/test_function.py | 4 +- streaming/python/tests/test_hybrid_stream.py | 70 ++++++ streaming/python/tests/test_serialization.py | 13 ++ streaming/python/tests/test_stream.py | 31 +++ streaming/python/tests/test_word_count.py | 4 +- ...eaming_runtime_transfer_TransferHandler.cc | 7 + 72 files changed, 1612 insertions(+), 408 deletions(-) create mode 100644 streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/ClusterStarter.java delete mode 100644 streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Message.java create mode 100644 streaming/java/streaming-api/src/test/java/io/ray/streaming/api/stream/StreamTest.java create mode 100644 streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/CrossLangSerializer.java create mode 100644 streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/JavaSerializer.java rename streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/{python => serialization}/MsgPackSerializer.java (90%) create mode 100644 streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/Serializer.java create mode 100644 streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/HybridStreamTest.java create mode 100644 streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/CrossLangSerializerTest.java rename streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/{python => serialization}/MsgPackSerializerTest.java (59%) create mode 100644 streaming/python/runtime/serialization.py create mode 100644 streaming/python/tests/test_hybrid_stream.py create mode 100644 streaming/python/tests/test_serialization.py create mode 100644 streaming/python/tests/test_stream.py diff --git a/python/ray/worker.py b/python/ray/worker.py index f28ac6fbc..cec67dba6 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -542,6 +542,7 @@ def init(address=None, raylet_socket_name=None, temp_dir=None, load_code_from_local=False, + java_worker_options=None, use_pickle=True, _internal_config=None, lru_evict=False): @@ -651,6 +652,7 @@ def init(address=None, conventional location, e.g., "/tmp/ray". load_code_from_local: Whether code should be loaded from a local module or from the GCS. + java_worker_options: Overwrite the options to start Java workers. use_pickle: Deprecated. _internal_config (str): JSON configuration for overriding RayConfig defaults. For testing purposes ONLY. @@ -758,6 +760,7 @@ def init(address=None, raylet_socket_name=raylet_socket_name, temp_dir=temp_dir, load_code_from_local=load_code_from_local, + java_worker_options=java_worker_options, _internal_config=_internal_config, ) # Start the Ray processes. We set shutdown_at_exit=False because we @@ -808,6 +811,9 @@ def init(address=None, if raylet_socket_name is not None: raise ValueError("When connecting to an existing cluster, " "raylet_socket_name must not be provided.") + if java_worker_options is not None: + raise ValueError("When connecting to an existing cluster, " + "java_worker_options must not be provided.") if _internal_config is not None and len(_internal_config) != 0: raise ValueError("When connecting to an existing cluster, " "_internal_config must not be provided.") diff --git a/streaming/java/BUILD.bazel b/streaming/java/BUILD.bazel index ccaefb3b0..8d66ebd91 100644 --- a/streaming/java/BUILD.bazel +++ b/streaming/java/BUILD.bazel @@ -39,6 +39,7 @@ define_java_module( ":io_ray_ray_streaming-state", ":io_ray_ray_streaming-api", "@ray_streaming_maven//:com_google_guava_guava", + "@ray_streaming_maven//:org_apache_commons_commons_lang3", "@ray_streaming_maven//:org_slf4j_slf4j_api", "@ray_streaming_maven//:org_slf4j_slf4j_log4j12", "@ray_streaming_maven//:org_testng_testng", @@ -46,7 +47,12 @@ define_java_module( visibility = ["//visibility:public"], deps = [ ":io_ray_ray_streaming-state", + "//java:io_ray_ray_api", + "//java:io_ray_ray_runtime", + "@ray_streaming_maven//:com_google_code_findbugs_jsr305", + "@ray_streaming_maven//:com_google_code_gson_gson", "@ray_streaming_maven//:com_google_guava_guava", + "@ray_streaming_maven//:org_apache_commons_commons_lang3", "@ray_streaming_maven//:org_slf4j_slf4j_api", "@ray_streaming_maven//:org_slf4j_slf4j_log4j12", ], @@ -129,8 +135,9 @@ define_java_module( ":io_ray_ray_streaming-api", ":io_ray_ray_streaming-runtime", "@ray_streaming_maven//:com_google_guava_guava", + "@ray_streaming_maven//:com_google_code_findbugs_jsr305", + "@ray_streaming_maven//:org_apache_commons_commons_lang3", "@ray_streaming_maven//:de_ruedigermoeller_fst", - "@ray_streaming_maven//:org_msgpack_msgpack_core", "@ray_streaming_maven//:org_aeonbits_owner_owner", "@ray_streaming_maven//:org_slf4j_slf4j_api", "@ray_streaming_maven//:org_slf4j_slf4j_log4j12", @@ -146,10 +153,12 @@ define_java_module( "//java:io_ray_ray_api", "//java:io_ray_ray_runtime", "@ray_streaming_maven//:com_github_davidmoten_flatbuffers_java", + "@ray_streaming_maven//:com_google_code_findbugs_jsr305", "@ray_streaming_maven//:com_google_guava_guava", "@ray_streaming_maven//:com_google_protobuf_protobuf_java", "@ray_streaming_maven//:de_ruedigermoeller_fst", "@ray_streaming_maven//:org_aeonbits_owner_owner", + "@ray_streaming_maven//:org_apache_commons_commons_lang3", "@ray_streaming_maven//:org_msgpack_msgpack_core", "@ray_streaming_maven//:org_slf4j_slf4j_api", "@ray_streaming_maven//:org_slf4j_slf4j_log4j12", diff --git a/streaming/java/dependencies.bzl b/streaming/java/dependencies.bzl index 40327336d..998d88434 100644 --- a/streaming/java/dependencies.bzl +++ b/streaming/java/dependencies.bzl @@ -6,8 +6,11 @@ def gen_streaming_java_deps(): artifacts = [ "com.beust:jcommander:1.72", "com.google.guava:guava:27.0.1-jre", + "com.google.code.findbugs:jsr305:3.0.2", + "com.google.code.gson:gson:2.8.5", "com.github.davidmoten:flatbuffers-java:1.9.0.1", "com.google.protobuf:protobuf-java:3.8.0", + "org.apache.commons:commons-lang3:3.4", "de.ruedigermoeller:fst:2.57", "org.aeonbits.owner:owner:1.0.10", "org.slf4j:slf4j-api:1.7.12", @@ -19,10 +22,9 @@ def gen_streaming_java_deps(): "org.apache.commons:commons-lang3:3.3.2", "org.msgpack:msgpack-core:0.8.20", "org.testng:testng:6.9.10", - "org.mockito:mockito-all:1.10.19", - "org.powermock:powermock-module-testng:1.6.6", - "org.powermock:powermock-api-mockito:1.6.6", - "org.projectlombok:lombok:1.16.20", + "org.mockito:mockito-all:1.10.19", + "org.powermock:powermock-module-testng:1.6.6", + "org.powermock:powermock-api-mockito:1.6.6", ], repositories = [ "https://repo1.maven.org/maven2/", diff --git a/streaming/java/streaming-api/pom.xml b/streaming/java/streaming-api/pom.xml index 253f7a3b4..4e100fefd 100644 --- a/streaming/java/streaming-api/pom.xml +++ b/streaming/java/streaming-api/pom.xml @@ -22,16 +22,36 @@ ray-api ${project.version} + + io.ray + ray-runtime + ${project.version} + org.ray streaming-state ${project.version} + com.google.code.findbugs + jsr305 + 3.0.2 + + + com.google.code.gson + gson + 2.8.5 + + com.google.guava guava 27.0.1-jre + + org.apache.commons + commons-lang3 + 3.4 + org.slf4j slf4j-api diff --git a/streaming/java/streaming-api/pom_template.xml b/streaming/java/streaming-api/pom_template.xml index 7c7171cdc..9b94fb278 100644 --- a/streaming/java/streaming-api/pom_template.xml +++ b/streaming/java/streaming-api/pom_template.xml @@ -22,6 +22,11 @@ ray-api ${project.version} + + io.ray + ray-runtime + ${project.version} + org.ray streaming-state diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/ClusterStarter.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/ClusterStarter.java new file mode 100644 index 000000000..0fe98798c --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/ClusterStarter.java @@ -0,0 +1,129 @@ +package io.ray.streaming.api.context; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.gson.Gson; +import io.ray.api.Ray; +import io.ray.runtime.config.RayConfig; +import io.ray.runtime.util.NetworkUtil; +import java.io.File; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class ClusterStarter { + private static final Logger LOG = LoggerFactory.getLogger(ClusterStarter.class); + private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/plasma_store_socket"; + private static final String RAYLET_SOCKET_NAME = "/tmp/ray/raylet_socket"; + + static synchronized void startCluster(boolean isCrossLanguage, boolean isLocal) { + Preconditions.checkArgument(Ray.internal() == null); + RayConfig.reset(); + if (!isLocal) { + System.setProperty("ray.raylet.config.num_workers_per_process_java", "1"); + System.setProperty("ray.run-mode", "CLUSTER"); + } else { + System.clearProperty("ray.raylet.config.num_workers_per_process_java"); + System.setProperty("ray.run-mode", "SINGLE_PROCESS"); + } + + if (!isCrossLanguage) { + Ray.init(); + return; + } + + // Delete existing socket files. + for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) { + File file = new File(socket); + if (file.exists()) { + LOG.info("Delete existing socket file {}", file); + file.delete(); + } + } + + String nodeManagerPort = String.valueOf(NetworkUtil.getUnusedPort()); + + // jars in the `ray` wheel doesn't contains test classes, so we add test classes explicitly. + // Since mvn test classes contains `test` in path and bazel test classes is located at a jar + // with `test` included in the name, we can check classpath `test` to filter out test classes. + String classpath = Stream.of(System.getProperty("java.class.path").split(":")) + .filter(s -> !s.contains(" ") && s.contains("test")) + .collect(Collectors.joining(":")); + String workerOptions = new Gson().toJson(ImmutableList.of("-classpath", classpath)); + Map config = new HashMap<>(RayConfig.create().rayletConfigParameters); + config.put("num_workers_per_process_java", "1"); + // Start ray cluster. + List startCommand = ImmutableList.of( + "ray", + "start", + "--head", + "--redis-port=6379", + String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME), + String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME), + String.format("--node-manager-port=%s", nodeManagerPort), + "--load-code-from-local", + "--include-java", + "--java-worker-options=" + workerOptions, + "--internal-config=" + new Gson().toJson(config) + ); + if (!executeCommand(startCommand, 10)) { + throw new RuntimeException("Couldn't start ray cluster."); + } + + // Connect to the cluster. + System.setProperty("ray.redis.address", "127.0.0.1:6379"); + System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME); + System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME); + System.setProperty("ray.raylet.node-manager-port", nodeManagerPort); + Ray.init(); + } + + public static synchronized void stopCluster(boolean isCrossLanguage) { + // Disconnect to the cluster. + Ray.shutdown(); + System.clearProperty("ray.redis.address"); + System.clearProperty("ray.object-store.socket-name"); + System.clearProperty("ray.raylet.socket-name"); + System.clearProperty("ray.raylet.node-manager-port"); + System.clearProperty("ray.raylet.config.num_workers_per_process_java"); + System.clearProperty("ray.run-mode"); + + if (isCrossLanguage) { + // Stop ray cluster. + final List stopCommand = ImmutableList.of( + "ray", + "stop" + ); + if (!executeCommand(stopCommand, 10)) { + throw new RuntimeException("Couldn't stop ray cluster"); + } + } + } + + /** + * Execute an external command. + * + * @return Whether the command succeeded. + */ + private static boolean executeCommand(List command, int waitTimeoutSeconds) { + LOG.info("Executing command: {}", String.join(" ", command)); + try { + ProcessBuilder processBuilder = new ProcessBuilder(command) + .redirectOutput(ProcessBuilder.Redirect.INHERIT) + .redirectError(ProcessBuilder.Redirect.INHERIT); + Process process = processBuilder.start(); + boolean exit = process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS); + if (!exit) { + process.destroyForcibly(); + } + return process.exitValue() == 0; + } catch (Exception e) { + throw new RuntimeException("Error executing command " + String.join(" ", command), e); + } + } +} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java index edf2fcd50..5f1ab4d4e 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java @@ -1,10 +1,12 @@ package io.ray.streaming.api.context; import com.google.common.base.Preconditions; +import io.ray.api.Ray; import io.ray.streaming.api.stream.StreamSink; import io.ray.streaming.jobgraph.JobGraph; import io.ray.streaming.jobgraph.JobGraphBuilder; import io.ray.streaming.schedule.JobScheduler; +import io.ray.streaming.util.Config; import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; @@ -13,11 +15,14 @@ import java.util.List; import java.util.Map; import java.util.ServiceLoader; import java.util.concurrent.atomic.AtomicInteger; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Encapsulate the context information of a streaming Job. */ public class StreamingContext implements Serializable { + private static final Logger LOG = LoggerFactory.getLogger(StreamingContext.class); private transient AtomicInteger idGenerator; @@ -54,6 +59,20 @@ public class StreamingContext implements Serializable { this.jobGraph = jobGraphBuilder.build(); jobGraph.printJobGraph(); + if (Ray.internal() == null) { + if (Config.MEMORY_CHANNEL.equalsIgnoreCase(jobConfig.get(Config.CHANNEL_TYPE))) { + Preconditions.checkArgument(!jobGraph.isCrossLanguageGraph()); + ClusterStarter.startCluster(false, true); + LOG.info("Created local cluster for job {}.", jobName); + } else { + ClusterStarter.startCluster(jobGraph.isCrossLanguageGraph(), false); + LOG.info("Created multi process cluster for job {}.", jobName); + } + Runtime.getRuntime().addShutdownHook(new Thread(StreamingContext.this::stop)); + } else { + LOG.info("Reuse existing cluster."); + } + ServiceLoader serviceLoader = ServiceLoader.load(JobScheduler.class); Iterator iterator = serviceLoader.iterator(); Preconditions.checkArgument(iterator.hasNext(), @@ -77,4 +96,10 @@ public class StreamingContext implements Serializable { public void withConfig(Map jobConfig) { this.jobConfig = jobConfig; } + + public void stop() { + if (Ray.internal() != null) { + ClusterStarter.stopCluster(jobGraph.isCrossLanguageGraph()); + } + } } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStream.java index b3b43fd6c..fe0a3af1b 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStream.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStream.java @@ -1,6 +1,7 @@ package io.ray.streaming.api.stream; +import io.ray.streaming.api.Language; import io.ray.streaming.api.context.StreamingContext; import io.ray.streaming.api.function.impl.FilterFunction; import io.ray.streaming.api.function.impl.FlatMapFunction; @@ -15,24 +16,44 @@ import io.ray.streaming.operator.impl.FlatMapOperator; import io.ray.streaming.operator.impl.KeyByOperator; import io.ray.streaming.operator.impl.MapOperator; import io.ray.streaming.operator.impl.SinkOperator; +import io.ray.streaming.python.stream.PythonDataStream; /** * Represents a stream of data. - * - * This class defines all the streaming operations. + *

This class defines all the streaming operations. * * @param Type of data in the stream. */ -public class DataStream extends Stream { +public class DataStream extends Stream, T> { public DataStream(StreamingContext streamingContext, StreamOperator streamOperator) { super(streamingContext, streamOperator); } - public DataStream(DataStream input, StreamOperator streamOperator) { + public DataStream(StreamingContext streamingContext, + StreamOperator streamOperator, + Partition partition) { + super(streamingContext, streamOperator, partition); + } + + public DataStream(DataStream input, StreamOperator streamOperator) { super(input, streamOperator); } + public DataStream(DataStream input, + StreamOperator streamOperator, + Partition partition) { + super(input, streamOperator, partition); + } + + /** + * Create a java stream that reference passed python stream. + * Changes in new stream will be reflected in referenced stream and vice versa + */ + public DataStream(PythonDataStream referencedStream) { + super(referencedStream); + } + /** * Apply a map function to this stream. * @@ -41,7 +62,7 @@ public class DataStream extends Stream { * @return A new DataStream. */ public DataStream map(MapFunction mapFunction) { - return new DataStream<>(this, new MapOperator(mapFunction)); + return new DataStream<>(this, new MapOperator<>(mapFunction)); } /** @@ -52,11 +73,11 @@ public class DataStream extends Stream { * @return A new DataStream */ public DataStream flatMap(FlatMapFunction flatMapFunction) { - return new DataStream(this, new FlatMapOperator(flatMapFunction)); + return new DataStream<>(this, new FlatMapOperator<>(flatMapFunction)); } public DataStream filter(FilterFunction filterFunction) { - return new DataStream(this, new FilterOperator(filterFunction)); + return new DataStream<>(this, new FilterOperator<>(filterFunction)); } /** @@ -66,7 +87,7 @@ public class DataStream extends Stream { * @return A new UnionStream. */ public UnionStream union(DataStream other) { - return new UnionStream(this, null, other); + return new UnionStream<>(this, null, other); } /** @@ -93,7 +114,7 @@ public class DataStream extends Stream { * @return A new StreamSink. */ public DataStreamSink sink(SinkFunction sinkFunction) { - return new DataStreamSink<>(this, new SinkOperator(sinkFunction)); + return new DataStreamSink<>(this, new SinkOperator<>(sinkFunction)); } /** @@ -104,7 +125,8 @@ public class DataStream extends Stream { * @return A new KeyDataStream. */ public KeyDataStream keyBy(KeyFunction keyFunction) { - return new KeyDataStream<>(this, new KeyByOperator(keyFunction)); + checkPartitionCall(); + return new KeyDataStream<>(this, new KeyByOperator<>(keyFunction)); } /** @@ -113,8 +135,8 @@ public class DataStream extends Stream { * @return This stream. */ public DataStream broadcast() { - this.partition = new BroadcastPartition<>(); - return this; + checkPartitionCall(); + return setPartition(new BroadcastPartition<>()); } /** @@ -124,19 +146,32 @@ public class DataStream extends Stream { * @return This stream. */ public DataStream partitionBy(Partition partition) { - this.partition = partition; - return this; + checkPartitionCall(); + return setPartition(partition); } /** - * Set parallelism to current transformation. - * - * @param parallelism The parallelism to set. - * @return This stream. + * If parent stream is a python stream, we can't call partition related methods + * in the java stream. */ - public DataStream setParallelism(int parallelism) { - this.parallelism = parallelism; - return this; + private void checkPartitionCall() { + if (getInputStream() != null && getInputStream().getLanguage() == Language.PYTHON) { + throw new RuntimeException("Partition related methods can't be called on a " + + "java stream if parent stream is a python stream."); + } } + /** + * Convert this stream as a python stream. + * The converted stream and this stream are the same logical stream, which has same stream id. + * Changes in converted stream will be reflected in this stream and vice versa. + */ + public PythonDataStream asPythonStream() { + return new PythonDataStream(this); + } + + @Override + public Language getLanguage() { + return Language.JAVA; + } } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSink.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSink.java index e19d1b027..e58bb420b 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSink.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSink.java @@ -1,5 +1,6 @@ package io.ray.streaming.api.stream; +import io.ray.streaming.api.Language; import io.ray.streaming.operator.impl.SinkOperator; /** @@ -9,13 +10,13 @@ import io.ray.streaming.operator.impl.SinkOperator; */ public class DataStreamSink extends StreamSink { - public DataStreamSink(DataStream input, SinkOperator sinkOperator) { + public DataStreamSink(DataStream input, SinkOperator sinkOperator) { super(input, sinkOperator); - this.streamingContext.addSink(this); + getStreamingContext().addSink(this); } - public DataStreamSink setParallelism(int parallelism) { - this.parallelism = parallelism; - return this; + @Override + public Language getLanguage() { + return Language.JAVA; } } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSource.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSource.java index 9f3b353ca..87ccb5eaf 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSource.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSource.java @@ -14,27 +14,26 @@ import java.util.Collection; */ public class DataStreamSource extends DataStream implements StreamSource { - public DataStreamSource(StreamingContext streamingContext, SourceFunction sourceFunction) { - super(streamingContext, new SourceOperator<>(sourceFunction)); - super.partition = new RoundRobinPartition<>(); + private DataStreamSource(StreamingContext streamingContext, SourceFunction sourceFunction) { + super(streamingContext, new SourceOperator<>(sourceFunction), new RoundRobinPartition<>()); + } + + public static DataStreamSource fromSource( + StreamingContext context, SourceFunction sourceFunction) { + return new DataStreamSource<>(context, sourceFunction); } /** * Build a DataStreamSource source from a collection. * * @param context Stream context. - * @param values A collection of values. - * @param The type of source data. + * @param values A collection of values. + * @param The type of source data. * @return A DataStreamSource. */ - public static DataStreamSource buildSource( + public static DataStreamSource fromCollection( StreamingContext context, Collection values) { - return new DataStreamSource(context, new CollectionSourceFunction(values)); + return new DataStreamSource<>(context, new CollectionSourceFunction<>(values)); } - @Override - public DataStreamSource setParallelism(int parallelism) { - this.parallelism = parallelism; - return this; - } } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/KeyDataStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/KeyDataStream.java index ad48f2efa..68708b9e9 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/KeyDataStream.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/KeyDataStream.java @@ -2,9 +2,12 @@ package io.ray.streaming.api.stream; import io.ray.streaming.api.function.impl.AggregateFunction; import io.ray.streaming.api.function.impl.ReduceFunction; +import io.ray.streaming.api.partition.Partition; import io.ray.streaming.api.partition.impl.KeyPartition; import io.ray.streaming.operator.StreamOperator; import io.ray.streaming.operator.impl.ReduceOperator; +import io.ray.streaming.python.stream.PythonDataStream; +import io.ray.streaming.python.stream.PythonKeyDataStream; /** * Represents a DataStream returned by a key-by operation. @@ -12,11 +15,19 @@ import io.ray.streaming.operator.impl.ReduceOperator; * @param Type of the key. * @param Type of the data. */ +@SuppressWarnings("unchecked") public class KeyDataStream extends DataStream { public KeyDataStream(DataStream input, StreamOperator streamOperator) { - super(input, streamOperator); - this.partition = new KeyPartition(); + super(input, streamOperator, (Partition) new KeyPartition()); + } + + /** + * Create a java stream that reference passed python stream. + * Changes in new stream will be reflected in referenced stream and vice versa + */ + public KeyDataStream(PythonDataStream referencedStream) { + super(referencedStream); } /** @@ -41,8 +52,13 @@ public class KeyDataStream extends DataStream { return new DataStream<>(this, null); } - public KeyDataStream setParallelism(int parallelism) { - this.parallelism = parallelism; - return this; + /** + * Convert this stream as a python stream. + * The converted stream and this stream are the same logical stream, which has same stream id. + * Changes in converted stream will be reflected in this stream and vice versa. + */ + public PythonKeyDataStream asPythonStream() { + return new PythonKeyDataStream(this); } + } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java index 791124c41..4c74780cd 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java @@ -1,58 +1,99 @@ package io.ray.streaming.api.stream; +import com.google.common.base.Preconditions; +import io.ray.streaming.api.Language; import io.ray.streaming.api.context.StreamingContext; import io.ray.streaming.api.partition.Partition; import io.ray.streaming.api.partition.impl.RoundRobinPartition; +import io.ray.streaming.operator.Operator; import io.ray.streaming.operator.StreamOperator; -import io.ray.streaming.python.PythonOperator; import io.ray.streaming.python.PythonPartition; -import io.ray.streaming.python.stream.PythonStream; import java.io.Serializable; /** * Abstract base class of all stream types. * + * @param Type of stream class * @param Type of the data in the stream. */ -public abstract class Stream implements Serializable { - protected int id; - protected int parallelism = 1; - protected StreamOperator operator; - protected Stream inputStream; - protected StreamingContext streamingContext; - protected Partition partition; +public abstract class Stream, T> + implements Serializable { + private final int id; + private final StreamingContext streamingContext; + private final Stream inputStream; + private final StreamOperator operator; + private int parallelism = 1; + private Partition partition; + private Stream originalStream; - @SuppressWarnings("unchecked") public Stream(StreamingContext streamingContext, StreamOperator streamOperator) { + this(streamingContext, null, streamOperator, + selectPartition(streamOperator)); + } + + public Stream(StreamingContext streamingContext, + StreamOperator streamOperator, + Partition partition) { + this(streamingContext, null, streamOperator, partition); + } + + public Stream(Stream inputStream, StreamOperator streamOperator) { + this(inputStream.getStreamingContext(), inputStream, streamOperator, + selectPartition(streamOperator)); + } + + public Stream(Stream inputStream, StreamOperator streamOperator, Partition partition) { + this(inputStream.getStreamingContext(), inputStream, streamOperator, partition); + } + + protected Stream(StreamingContext streamingContext, + Stream inputStream, + StreamOperator streamOperator, + Partition partition) { this.streamingContext = streamingContext; + this.inputStream = inputStream; this.operator = streamOperator; + this.partition = partition; this.id = streamingContext.generateId(); - if (streamOperator instanceof PythonOperator) { - this.partition = PythonPartition.RoundRobinPartition; - } else { - this.partition = new RoundRobinPartition<>(); + if (inputStream != null) { + this.parallelism = inputStream.getParallelism(); } } - public Stream(Stream inputStream, StreamOperator streamOperator) { - this.inputStream = inputStream; - this.parallelism = inputStream.getParallelism(); - this.streamingContext = this.inputStream.getStreamingContext(); - this.operator = streamOperator; - this.id = streamingContext.generateId(); - this.partition = selectPartition(); + /** + * Create a proxy stream of original stream. + * Changes in new stream will be reflected in original stream and vice versa + */ + protected Stream(Stream originalStream) { + this.originalStream = originalStream; + this.id = originalStream.getId(); + this.streamingContext = originalStream.getStreamingContext(); + this.inputStream = originalStream.getInputStream(); + this.operator = originalStream.getOperator(); } @SuppressWarnings("unchecked") - private Partition selectPartition() { - if (inputStream instanceof PythonStream) { - return PythonPartition.RoundRobinPartition; - } else { - return new RoundRobinPartition<>(); + private static Partition selectPartition(Operator operator) { + switch (operator.getLanguage()) { + case PYTHON: + return (Partition) PythonPartition.RoundRobinPartition; + case JAVA: + return new RoundRobinPartition<>(); + default: + throw new UnsupportedOperationException( + "Unsupported language " + operator.getLanguage()); } } - public Stream getInputStream() { + public int getId() { + return id; + } + + public StreamingContext getStreamingContext() { + return streamingContext; + } + + public Stream getInputStream() { return inputStream; } @@ -60,32 +101,47 @@ public abstract class Stream implements Serializable { return operator; } - public void setOperator(StreamOperator operator) { - this.operator = operator; - } - - public StreamingContext getStreamingContext() { - return streamingContext; + @SuppressWarnings("unchecked") + private S self() { + return (S) this; } public int getParallelism() { - return parallelism; + return originalStream != null ? originalStream.getParallelism() : parallelism; } - public Stream setParallelism(int parallelism) { - this.parallelism = parallelism; - return this; - } - - public int getId() { - return id; + public S setParallelism(int parallelism) { + if (originalStream != null) { + originalStream.setParallelism(parallelism); + } else { + this.parallelism = parallelism; + } + return self(); } + @SuppressWarnings("unchecked") public Partition getPartition() { - return partition; + return originalStream != null ? originalStream.getPartition() : partition; } - public void setPartition(Partition partition) { - this.partition = partition; + @SuppressWarnings("unchecked") + protected S setPartition(Partition partition) { + if (originalStream != null) { + originalStream.setPartition(partition); + } else { + this.partition = partition; + } + return self(); } + + public boolean isProxyStream() { + return originalStream != null; + } + + public Stream getOriginalStream() { + Preconditions.checkArgument(isProxyStream()); + return originalStream; + } + + public abstract Language getLanguage(); } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/StreamSink.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/StreamSink.java index 944b93eae..f03b1baa4 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/StreamSink.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/StreamSink.java @@ -7,8 +7,8 @@ import io.ray.streaming.operator.StreamOperator; * * @param Type of the input data of this sink. */ -public class StreamSink extends Stream { - public StreamSink(Stream inputStream, StreamOperator streamOperator) { +public abstract class StreamSink extends Stream, T> { + public StreamSink(Stream inputStream, StreamOperator streamOperator) { super(inputStream, streamOperator); } } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java index ed7434c5c..6dd559ce7 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java @@ -11,15 +11,15 @@ import java.util.List; */ public class UnionStream extends DataStream { - private List unionStreams; + private List> unionStreams; - public UnionStream(DataStream input, StreamOperator streamOperator, DataStream other) { + public UnionStream(DataStream input, StreamOperator streamOperator, DataStream other) { super(input, streamOperator); this.unionStreams = new ArrayList<>(); this.unionStreams.add(other); } - public List getUnionStreams() { + public List> getUnionStreams() { return unionStreams; } } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraph.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraph.java index 675cad1ea..e670e5ea3 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraph.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraph.java @@ -1,5 +1,6 @@ package io.ray.streaming.jobgraph; +import io.ray.streaming.api.Language; import java.io.Serializable; import java.util.ArrayList; import java.util.List; @@ -97,4 +98,14 @@ public class JobGraph implements Serializable { } } + public boolean isCrossLanguageGraph() { + Language language = jobVertexList.get(0).getLanguage(); + for (JobVertex jobVertex : jobVertexList) { + if (jobVertex.getLanguage() != language) { + return true; + } + } + return false; + } + } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java index b8d5af9a4..d0f6a7dc3 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java @@ -1,5 +1,6 @@ package io.ray.streaming.jobgraph; +import com.google.common.base.Preconditions; import io.ray.streaming.api.stream.DataStream; import io.ray.streaming.api.stream.Stream; import io.ray.streaming.api.stream.StreamSink; @@ -10,8 +11,11 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class JobGraphBuilder { + private static final Logger LOG = LoggerFactory.getLogger(JobGraphBuilder.class); private JobGraph jobGraph; @@ -41,12 +45,19 @@ public class JobGraphBuilder { } private void processStream(Stream stream) { + while (stream.isProxyStream()) { + // Proxy stream and original stream are the same logical stream, both refer to the + // same data flow transformation. We should skip proxy stream to avoid applying same + // transformation multiple times. + LOG.debug("Skip proxy stream {} of id {}", stream, stream.getId()); + stream = stream.getOriginalStream(); + } + StreamOperator streamOperator = stream.getOperator(); + Preconditions.checkArgument(stream.getLanguage() == streamOperator.getLanguage(), + "Reference stream should be skipped."); int vertexId = stream.getId(); int parallelism = stream.getParallelism(); - - StreamOperator streamOperator = stream.getOperator(); - JobVertex jobVertex = null; - + JobVertex jobVertex; if (stream instanceof StreamSink) { jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator); Stream parentStream = stream.getInputStream(); diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java index d91b4cbd5..c99ec9959 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java @@ -1,6 +1,8 @@ package io.ray.streaming.message; +import java.util.Objects; + public class KeyRecord extends Record { private K key; @@ -17,4 +19,24 @@ public class KeyRecord extends Record { public void setKey(K key) { this.key = key; } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + KeyRecord keyRecord = (KeyRecord) o; + return Objects.equals(key, keyRecord.key); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), key); + } } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Message.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Message.java deleted file mode 100644 index a943dcb9d..000000000 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Message.java +++ /dev/null @@ -1,64 +0,0 @@ -package io.ray.streaming.message; - -import com.google.common.collect.Lists; -import java.io.Serializable; -import java.util.List; - -public class Message implements Serializable { - - private int taskId; - private long batchId; - private String stream; - private List recordList; - - public Message(int taskId, long batchId, String stream, List recordList) { - this.taskId = taskId; - this.batchId = batchId; - this.stream = stream; - this.recordList = recordList; - } - - public Message(int taskId, long batchId, String stream, Record record) { - this.taskId = taskId; - this.batchId = batchId; - this.stream = stream; - this.recordList = Lists.newArrayList(record); - } - - public int getTaskId() { - return taskId; - } - - public void setTaskId(int taskId) { - this.taskId = taskId; - } - - public long getBatchId() { - return batchId; - } - - public void setBatchId(long batchId) { - this.batchId = batchId; - } - - public String getStream() { - return stream; - } - - public void setStream(String stream) { - this.stream = stream; - } - - public List getRecordList() { - return recordList; - } - - public void setRecordList(List recordList) { - this.recordList = recordList; - } - - public Record getRecord(int index) { - return recordList.get(0); - } - -} diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Record.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Record.java index 8d0ca368b..c86e47645 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Record.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Record.java @@ -1,6 +1,7 @@ package io.ray.streaming.message; import java.io.Serializable; +import java.util.Objects; public class Record implements Serializable { @@ -27,6 +28,24 @@ public class Record implements Serializable { this.stream = stream; } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Record record = (Record) o; + return Objects.equals(stream, record.stream) && + Objects.equals(value, record.value); + } + + @Override + public int hashCode() { + return Objects.hash(stream, value); + } + @Override public String toString() { return value.toString(); diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java index 31c82b43e..21533de70 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java @@ -1,6 +1,8 @@ package io.ray.streaming.python; +import com.google.common.base.Preconditions; import io.ray.streaming.api.function.Function; +import org.apache.commons.lang3.StringUtils; /** * Represents a user defined python function. @@ -14,9 +16,8 @@ import io.ray.streaming.api.function.Function; * *

If the python data stream api is invoked from python, `function` will be not null.

*

If the python data stream api is invoked from java, `moduleName` and - * `className`/`functionName` will be not null.

+ * `functionName` will be not null.

*

- * TODO serialize to bytes using protobuf */ public class PythonFunction implements Function { public enum FunctionInterface { @@ -38,23 +39,43 @@ public class PythonFunction implements Function { } } - private byte[] function; - private String moduleName; - private String className; - private String functionName; + // null if this function is constructed from moduleName/functionName. + private final byte[] function; + // null if this function is constructed from serialized python function. + private final String moduleName; + // null if this function is constructed from serialized python function. + private final String functionName; /** * FunctionInterface can be used to validate python function, * and look up operator class from FunctionInterface. */ private String functionInterface; - private PythonFunction(byte[] function, - String moduleName, - String className, - String functionName) { + /** + * Create a {@link PythonFunction} from a serialized streaming python function. + * + * @param function serialized streaming python function from python driver. + */ + public PythonFunction(byte[] function) { + Preconditions.checkNotNull(function); this.function = function; + this.moduleName = null; + this.functionName = null; + } + + /** + * Create a {@link PythonFunction} from a moduleName and streaming function name. + * + * @param moduleName module name of streaming function. + * @param functionName function name of streaming function. {@code functionName} is the name + * of a python function, or class name of subclass of `ray.streaming.function.` + */ + public PythonFunction(String moduleName, + String functionName) { + Preconditions.checkArgument(StringUtils.isNotBlank(moduleName)); + Preconditions.checkArgument(StringUtils.isNotBlank(functionName)); + this.function = null; this.moduleName = moduleName; - this.className = className; this.functionName = functionName; } @@ -70,10 +91,6 @@ public class PythonFunction implements Function { return moduleName; } - public String getClassName() { - return className; - } - public String getFunctionName() { return functionName; } @@ -82,34 +99,4 @@ public class PythonFunction implements Function { return functionInterface; } - /** - * Create a {@link PythonFunction} using python serialized function - * - * @param function serialized python function sent from python driver - */ - public static PythonFunction fromFunction(byte[] function) { - return new PythonFunction(function, null, null, null); - } - - /** - * Create a {@link PythonFunction} using moduleName and - * className. - * - * @param moduleName python module name - * @param className python class name - */ - public static PythonFunction fromClassName(String moduleName, String className) { - return new PythonFunction(null, moduleName, className, null); - } - - /** - * Create a {@link PythonFunction} using moduleName and - * functionName. - * - * @param moduleName python module name - * @param functionName python function name - */ - public static PythonFunction fromFunctionName(String moduleName, String functionName) { - return new PythonFunction(null, moduleName, null, functionName); - } } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonPartition.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonPartition.java index c4548031b..6d8de051f 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonPartition.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonPartition.java @@ -1,6 +1,8 @@ package io.ray.streaming.python; +import com.google.common.base.Preconditions; import io.ray.streaming.api.partition.Partition; +import org.apache.commons.lang3.StringUtils; /** * Represents a python partition function. @@ -13,28 +15,33 @@ import io.ray.streaming.api.partition.Partition; * If this object is constructed from moduleName and className/functionName, * python worker will use `importlib` to load python partition function. *

- * TODO serialize to bytes using protobuf */ -public class PythonPartition implements Partition { +public class PythonPartition implements Partition { public static final PythonPartition BroadcastPartition = new PythonPartition( - "ray.streaming.partition", "BroadcastPartition", null); + "ray.streaming.partition", "BroadcastPartition"); public static final PythonPartition KeyPartition = new PythonPartition( - "ray.streaming.partition", "KeyPartition", null); + "ray.streaming.partition", "KeyPartition"); public static final PythonPartition RoundRobinPartition = new PythonPartition( - "ray.streaming.partition", "RoundRobinPartition", null); + "ray.streaming.partition", "RoundRobinPartition"); private byte[] partition; private String moduleName; - private String className; private String functionName; public PythonPartition(byte[] partition) { + Preconditions.checkNotNull(partition); this.partition = partition; } - public PythonPartition(String moduleName, String className, String functionName) { + /** + * Create a python partition from a moduleName and partition function name + * @param moduleName module name of python partition + * @param functionName function/class name of the partition function. + */ + public PythonPartition(String moduleName, String functionName) { + Preconditions.checkArgument(StringUtils.isNotBlank(moduleName)); + Preconditions.checkArgument(StringUtils.isNotBlank(functionName)); this.moduleName = moduleName; - this.className = className; this.functionName = functionName; } @@ -53,10 +60,6 @@ public class PythonPartition implements Partition { return moduleName; } - public String getClassName() { - return className; - } - public String getFunctionName() { return functionName; } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonDataStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonDataStream.java index da8be7368..8911fde13 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonDataStream.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonDataStream.java @@ -1,6 +1,9 @@ package io.ray.streaming.python.stream; +import io.ray.streaming.api.Language; import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.partition.Partition; +import io.ray.streaming.api.stream.DataStream; import io.ray.streaming.api.stream.Stream; import io.ray.streaming.python.PythonFunction; import io.ray.streaming.python.PythonFunction.FunctionInterface; @@ -10,19 +13,39 @@ import io.ray.streaming.python.PythonPartition; /** * Represents a stream of data whose transformations will be executed in python. */ -public class PythonDataStream extends Stream implements PythonStream { +public class PythonDataStream extends Stream implements PythonStream { protected PythonDataStream(StreamingContext streamingContext, PythonOperator pythonOperator) { super(streamingContext, pythonOperator); } + protected PythonDataStream(StreamingContext streamingContext, + PythonOperator pythonOperator, + Partition partition) { + super(streamingContext, pythonOperator, partition); + } + public PythonDataStream(PythonDataStream input, PythonOperator pythonOperator) { super(input, pythonOperator); } - protected PythonDataStream(Stream inputStream, PythonOperator pythonOperator) { - super(inputStream, pythonOperator); + public PythonDataStream(PythonDataStream input, + PythonOperator pythonOperator, + Partition partition) { + super(input, pythonOperator, partition); + } + + /** + * Create a python stream that reference passed java stream. + * Changes in new stream will be reflected in referenced stream and vice versa + */ + public PythonDataStream(DataStream referencedStream) { + super(referencedStream); + } + + public PythonDataStream map(String moduleName, String funcName) { + return map(new PythonFunction(moduleName, funcName)); } /** @@ -36,6 +59,10 @@ public class PythonDataStream extends Stream implements PythonStream { return new PythonDataStream(this, new PythonOperator(func)); } + public PythonDataStream flatMap(String moduleName, String funcName) { + return flatMap(new PythonFunction(moduleName, funcName)); + } + /** * Apply a flat-map function to this stream. * @@ -47,6 +74,10 @@ public class PythonDataStream extends Stream implements PythonStream { return new PythonDataStream(this, new PythonOperator(func)); } + public PythonDataStream filter(String moduleName, String funcName) { + return filter(new PythonFunction(moduleName, funcName)); + } + /** * Apply a filter function to this stream. * @@ -59,6 +90,10 @@ public class PythonDataStream extends Stream implements PythonStream { return new PythonDataStream(this, new PythonOperator(func)); } + public PythonStreamSink sink(String moduleName, String funcName) { + return sink(new PythonFunction(moduleName, funcName)); + } + /** * Apply a sink function and get a StreamSink. * @@ -70,6 +105,10 @@ public class PythonDataStream extends Stream implements PythonStream { return new PythonStreamSink(this, new PythonOperator(func)); } + public PythonKeyDataStream keyBy(String moduleName, String funcName) { + return keyBy(new PythonFunction(moduleName, funcName)); + } + /** * Apply a key-by function to this stream. * @@ -77,6 +116,7 @@ public class PythonDataStream extends Stream implements PythonStream { * @return A new KeyDataStream. */ public PythonKeyDataStream keyBy(PythonFunction func) { + checkPartitionCall(); func.setFunctionInterface(FunctionInterface.KEY_FUNCTION); return new PythonKeyDataStream(this, new PythonOperator(func)); } @@ -87,8 +127,8 @@ public class PythonDataStream extends Stream implements PythonStream { * @return This stream. */ public PythonDataStream broadcast() { - this.partition = PythonPartition.BroadcastPartition; - return this; + checkPartitionCall(); + return setPartition(PythonPartition.BroadcastPartition); } /** @@ -98,19 +138,33 @@ public class PythonDataStream extends Stream implements PythonStream { * @return This stream. */ public PythonDataStream partitionBy(PythonPartition partition) { - this.partition = partition; - return this; + checkPartitionCall(); + return setPartition(partition); } /** - * Set parallelism to current transformation. - * - * @param parallelism The parallelism to set. - * @return This stream. + * If parent stream is a python stream, we can't call partition related methods + * in the java stream. */ - public PythonDataStream setParallelism(int parallelism) { - this.parallelism = parallelism; - return this; + private void checkPartitionCall() { + if (getInputStream() != null && getInputStream().getLanguage() == Language.JAVA) { + throw new RuntimeException("Partition related methods can't be called on a " + + "python stream if parent stream is a java stream."); + } + } + + /** + * Convert this stream as a java stream. + * The converted stream and this stream are the same logical stream, which has same stream id. + * Changes in converted stream will be reflected in this stream and vice versa. + */ + public DataStream asJavaStream() { + return new DataStream<>(this); + } + + @Override + public Language getLanguage() { + return Language.PYTHON; } } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonKeyDataStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonKeyDataStream.java index 780a63866..a095b761b 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonKeyDataStream.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonKeyDataStream.java @@ -1,5 +1,7 @@ package io.ray.streaming.python.stream; +import io.ray.streaming.api.stream.DataStream; +import io.ray.streaming.api.stream.KeyDataStream; import io.ray.streaming.python.PythonFunction; import io.ray.streaming.python.PythonFunction.FunctionInterface; import io.ray.streaming.python.PythonOperator; @@ -8,11 +10,23 @@ import io.ray.streaming.python.PythonPartition; /** * Represents a python DataStream returned by a key-by operation. */ -public class PythonKeyDataStream extends PythonDataStream implements PythonStream { +@SuppressWarnings("unchecked") +public class PythonKeyDataStream extends PythonDataStream implements PythonStream { public PythonKeyDataStream(PythonDataStream input, PythonOperator pythonOperator) { - super(input, pythonOperator); - this.partition = PythonPartition.KeyPartition; + super(input, pythonOperator, PythonPartition.KeyPartition); + } + + /** + * Create a python stream that reference passed python stream. + * Changes in new stream will be reflected in referenced stream and vice versa + */ + public PythonKeyDataStream(DataStream referencedStream) { + super(referencedStream); + } + + public PythonDataStream reduce(String moduleName, String funcName) { + return reduce(new PythonFunction(moduleName, funcName)); } /** @@ -26,9 +40,13 @@ public class PythonKeyDataStream extends PythonDataStream implements PythonStrea return new PythonDataStream(this, new PythonOperator(func)); } - public PythonKeyDataStream setParallelism(int parallelism) { - this.parallelism = parallelism; - return this; + /** + * Convert this stream as a java stream. + * The converted stream and this stream are the same logical stream, which has same stream id. + * Changes in converted stream will be reflected in this stream and vice versa. + */ + public KeyDataStream asJavaStream() { + return new KeyDataStream(this); } } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSink.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSink.java index d8bfddaa5..5a7dd60ab 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSink.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSink.java @@ -1,5 +1,6 @@ package io.ray.streaming.python.stream; +import io.ray.streaming.api.Language; import io.ray.streaming.api.stream.StreamSink; import io.ray.streaming.python.PythonOperator; @@ -9,12 +10,12 @@ import io.ray.streaming.python.PythonOperator; public class PythonStreamSink extends StreamSink implements PythonStream { public PythonStreamSink(PythonDataStream input, PythonOperator sinkOperator) { super(input, sinkOperator); - this.streamingContext.addSink(this); + getStreamingContext().addSink(this); } - public PythonStreamSink setParallelism(int parallelism) { - this.parallelism = parallelism; - return this; + @Override + public Language getLanguage() { + return Language.PYTHON; } } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSource.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSource.java index 25d2ebba0..a123c5cf8 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSource.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/stream/PythonStreamSource.java @@ -13,17 +13,12 @@ import io.ray.streaming.python.PythonPartition; public class PythonStreamSource extends PythonDataStream implements StreamSource { private PythonStreamSource(StreamingContext streamingContext, PythonFunction sourceFunction) { - super(streamingContext, new PythonOperator(sourceFunction)); - super.partition = PythonPartition.RoundRobinPartition; - } - - public PythonStreamSource setParallelism(int parallelism) { - this.parallelism = parallelism; - return this; + super(streamingContext, new PythonOperator(sourceFunction), + PythonPartition.RoundRobinPartition); } public static PythonStreamSource from(StreamingContext streamingContext, - PythonFunction sourceFunction) { + PythonFunction sourceFunction) { sourceFunction.setFunctionInterface(FunctionInterface.SOURCE_FUNCTION); return new PythonStreamSource(streamingContext, sourceFunction); } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/util/Config.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/util/Config.java index 9209d5d0d..9862c1f92 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/util/Config.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/util/Config.java @@ -21,7 +21,6 @@ public class Config { public static final String CHANNEL_TYPE = "channel_type"; public static final String MEMORY_CHANNEL = "memory_channel"; public static final String NATIVE_CHANNEL = "native_channel"; - public static final String DEFAULT_CHANNEL_TYPE = NATIVE_CHANNEL; public static final String CHANNEL_SIZE = "channel_size"; public static final String CHANNEL_SIZE_DEFAULT = String.valueOf((long)Math.pow(10, 8)); public static final String IS_RECREATE = "streaming.is_recreate"; diff --git a/streaming/java/streaming-api/src/test/java/io/ray/streaming/api/stream/StreamTest.java b/streaming/java/streaming-api/src/test/java/io/ray/streaming/api/stream/StreamTest.java new file mode 100644 index 000000000..fe05b52d5 --- /dev/null +++ b/streaming/java/streaming-api/src/test/java/io/ray/streaming/api/stream/StreamTest.java @@ -0,0 +1,40 @@ +package io.ray.streaming.api.stream; + +import static org.testng.Assert.assertEquals; + +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.operator.impl.MapOperator; +import io.ray.streaming.python.stream.PythonDataStream; +import io.ray.streaming.python.stream.PythonKeyDataStream; +import org.testng.annotations.Test; + +@SuppressWarnings("unchecked") +public class StreamTest { + + @Test + public void testReferencedDataStream() { + DataStream dataStream = new DataStream(StreamingContext.buildContext(), + new MapOperator(value -> null)); + PythonDataStream pythonDataStream = dataStream.asPythonStream(); + DataStream javaStream = pythonDataStream.asJavaStream(); + assertEquals(dataStream.getId(), pythonDataStream.getId()); + assertEquals(dataStream.getId(), javaStream.getId()); + javaStream.setParallelism(10); + assertEquals(dataStream.getParallelism(), pythonDataStream.getParallelism()); + assertEquals(dataStream.getParallelism(), javaStream.getParallelism()); + } + + @Test + public void testReferencedKeyDataStream() { + DataStream dataStream = new DataStream(StreamingContext.buildContext(), + new MapOperator(value -> null)); + KeyDataStream keyDataStream = dataStream.keyBy(value -> null); + PythonKeyDataStream pythonKeyDataStream = keyDataStream.asPythonStream(); + KeyDataStream javaKeyDataStream = pythonKeyDataStream.asJavaStream(); + assertEquals(keyDataStream.getId(), pythonKeyDataStream.getId()); + assertEquals(keyDataStream.getId(), javaKeyDataStream.getId()); + javaKeyDataStream.setParallelism(10); + assertEquals(keyDataStream.getParallelism(), pythonKeyDataStream.getParallelism()); + assertEquals(keyDataStream.getParallelism(), javaKeyDataStream.getParallelism()); + } +} \ No newline at end of file diff --git a/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphBuilderTest.java b/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphBuilderTest.java index cb1974b4b..1f9c367df 100644 --- a/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphBuilderTest.java +++ b/streaming/java/streaming-api/src/test/java/io/ray/streaming/jobgraph/JobGraphBuilderTest.java @@ -38,7 +38,7 @@ public class JobGraphBuilderTest { public JobGraph buildDataSyncJobGraph() { StreamingContext streamingContext = StreamingContext.buildContext(); - DataStream dataStream = DataStreamSource.buildSource(streamingContext, + DataStream dataStream = DataStreamSource.fromCollection(streamingContext, Lists.newArrayList("a", "b", "c")); StreamSink streamSink = dataStream.sink(x -> LOG.info(x)); JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink)); @@ -73,7 +73,7 @@ public class JobGraphBuilderTest { public JobGraph buildKeyByJobGraph() { StreamingContext streamingContext = StreamingContext.buildContext(); - DataStream dataStream = DataStreamSource.buildSource(streamingContext, + DataStream dataStream = DataStreamSource.fromCollection(streamingContext, Lists.newArrayList("1", "2", "3", "4")); StreamSink streamSink = dataStream.keyBy(x -> x) .sink(x -> LOG.info(x)); diff --git a/streaming/java/streaming-runtime/pom.xml b/streaming/java/streaming-runtime/pom.xml index d2a8577a5..19633c4be 100644 --- a/streaming/java/streaming-runtime/pom.xml +++ b/streaming/java/streaming-runtime/pom.xml @@ -36,6 +36,11 @@ flatbuffers-java 1.9.0.1 + + com.google.code.findbugs + jsr305 + 3.0.2 + com.google.guava guava @@ -56,6 +61,11 @@ owner 1.0.10 + + org.apache.commons + commons-lang3 + 3.4 + org.mockito mockito-all @@ -71,11 +81,6 @@ powermock-api-mockito 1.6.6 - - org.powermock - powermock-core - 1.6.6 - org.powermock powermock-module-testng diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java index 7939e69b1..566dd15b8 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java @@ -1,9 +1,14 @@ package io.ray.streaming.runtime.core.collector; -import io.ray.runtime.serializer.Serializer; +import io.ray.api.BaseActor; +import io.ray.api.RayPyActor; +import io.ray.streaming.api.Language; import io.ray.streaming.api.collector.Collector; import io.ray.streaming.api.partition.Partition; import io.ray.streaming.message.Record; +import io.ray.streaming.runtime.serialization.CrossLangSerializer; +import io.ray.streaming.runtime.serialization.JavaSerializer; +import io.ray.streaming.runtime.serialization.Serializer; import io.ray.streaming.runtime.transfer.ChannelID; import io.ray.streaming.runtime.transfer.DataWriter; import java.nio.ByteBuffer; @@ -14,15 +19,24 @@ import org.slf4j.LoggerFactory; public class OutputCollector implements Collector { private static final Logger LOGGER = LoggerFactory.getLogger(OutputCollector.class); - private Partition partition; - private DataWriter writer; - private ChannelID[] outputQueues; + private final DataWriter writer; + private final ChannelID[] outputQueues; + private final Collection targetActors; + private final Language[] targetLanguages; + private final Partition partition; + private final Serializer javaSerializer = new JavaSerializer(); + private final Serializer crossLangSerializer = new CrossLangSerializer(); - public OutputCollector(Collection outputQueueIds, - DataWriter writer, + public OutputCollector(DataWriter writer, + Collection outputQueueIds, + Collection targetActors, Partition partition) { - this.outputQueues = outputQueueIds.stream().map(ChannelID::from).toArray(ChannelID[]::new); this.writer = writer; + this.outputQueues = outputQueueIds.stream().map(ChannelID::from).toArray(ChannelID[]::new); + this.targetActors = targetActors; + this.targetLanguages = targetActors.stream() + .map(actor -> actor instanceof RayPyActor ? Language.PYTHON : Language.JAVA) + .toArray(Language[]::new); this.partition = partition; LOGGER.debug("OutputCollector constructed, outputQueueIds:{}, partition:{}.", outputQueueIds, this.partition); @@ -31,9 +45,32 @@ public class OutputCollector implements Collector { @Override public void collect(Record record) { int[] partitions = this.partition.partition(record, outputQueues.length); - ByteBuffer msgBuffer = ByteBuffer.wrap(Serializer.encode(record).getLeft()); + ByteBuffer javaBuffer = null; + ByteBuffer crossLangBuffer = null; for (int partition : partitions) { - writer.write(outputQueues[partition], msgBuffer); + if (targetLanguages[partition] == Language.JAVA) { + // avoid repeated serialization + if (javaBuffer == null) { + byte[] bytes = javaSerializer.serialize(record); + javaBuffer = ByteBuffer.allocate(1 + bytes.length); + javaBuffer.put(Serializer.JAVA_TYPE_ID); + // TODO(chaokunyang) remove copy + javaBuffer.put(bytes); + javaBuffer.flip(); + } + writer.write(outputQueues[partition], javaBuffer.duplicate()); + } else { + // avoid repeated serialization + if (crossLangBuffer == null) { + byte[] bytes = crossLangSerializer.serialize(record); + crossLangBuffer = ByteBuffer.allocate(1 + bytes.length); + crossLangBuffer.put(Serializer.CROSS_LANG_TYPE_ID); + // TODO(chaokunyang) remove copy + crossLangBuffer.put(bytes); + crossLangBuffer.flip(); + } + writer.write(outputQueues[partition], crossLangBuffer.duplicate()); + } } } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java index d90e02463..b79926490 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java @@ -12,6 +12,7 @@ import io.ray.streaming.runtime.core.graph.ExecutionNode; import io.ray.streaming.runtime.core.graph.ExecutionTask; import io.ray.streaming.runtime.generated.RemoteCall; import io.ray.streaming.runtime.generated.Streaming; +import io.ray.streaming.runtime.serialization.MsgPackSerializer; import java.util.Arrays; public class GraphPbBuilder { @@ -74,11 +75,10 @@ public class GraphPbBuilder { private byte[] serializeFunction(Function function) { if (function instanceof PythonFunction) { PythonFunction pyFunc = (PythonFunction) function; - // function_bytes, module_name, class_name, function_name, function_interface + // function_bytes, module_name, function_name, function_interface return serializer.serialize(Arrays.asList( pyFunc.getFunction(), pyFunc.getModuleName(), - pyFunc.getClassName(), pyFunc.getFunctionName(), - pyFunc.getFunctionInterface() + pyFunc.getFunctionName(), pyFunc.getFunctionInterface() )); } else { return new byte[0]; @@ -88,10 +88,10 @@ public class GraphPbBuilder { private byte[] serializePartition(Partition partition) { if (partition instanceof PythonPartition) { PythonPartition pythonPartition = (PythonPartition) partition; - // partition_bytes, module_name, class_name, function_name + // partition_bytes, module_name, function_name return serializer.serialize(Arrays.asList( pythonPartition.getPartition(), pythonPartition.getModuleName(), - pythonPartition.getClassName(), pythonPartition.getFunctionName() + pythonPartition.getFunctionName() )); } else { return new byte[0]; diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java index 80a487db0..826f1c935 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java @@ -1,16 +1,21 @@ package io.ray.streaming.runtime.python; +import com.google.common.base.Preconditions; +import com.google.common.primitives.Primitives; import io.ray.streaming.api.context.StreamingContext; import io.ray.streaming.python.PythonFunction; import io.ray.streaming.python.PythonPartition; import io.ray.streaming.python.stream.PythonStreamSource; +import io.ray.streaming.runtime.serialization.MsgPackSerializer; import io.ray.streaming.runtime.util.ReflectionUtils; import java.lang.reflect.Method; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.function.Function; import java.util.stream.Collectors; -import org.msgpack.core.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -68,7 +73,7 @@ public class PythonGateway { Preconditions.checkNotNull(streamingContext); try { PythonStreamSource pythonStreamSource = PythonStreamSource.from( - streamingContext, PythonFunction.fromFunction(pySourceFunc)); + streamingContext, new PythonFunction(pySourceFunc)); referenceMap.put(getReferenceId(pythonStreamSource), pythonStreamSource); return serializer.serialize(getReferenceId(pythonStreamSource)); } catch (Exception e) { @@ -84,7 +89,7 @@ public class PythonGateway { } public byte[] createPyFunc(byte[] pyFunc) { - PythonFunction function = PythonFunction.fromFunction(pyFunc); + PythonFunction function = new PythonFunction(pyFunc); referenceMap.put(getReferenceId(function), function); return serializer.serialize(getReferenceId(function)); } @@ -98,15 +103,21 @@ public class PythonGateway { public byte[] callFunction(byte[] paramsBytes) { try { List params = (List) serializer.deserialize(paramsBytes); - params = processReferenceParameters(params); + params = processParameters(params); LOG.info("callFunction params {}", params); String className = (String) params.get(0); String funcName = (String) params.get(1); Class clz = Class.forName(className, true, this.getClass().getClassLoader()); - Method method = ReflectionUtils.findMethod(clz, funcName); + Class[] paramsTypes = params.subList(2, params.size()).stream() + .map(Object::getClass).toArray(Class[]::new); + Method method = findMethod(clz, funcName, paramsTypes); Object result = method.invoke(null, params.subList(2, params.size()).toArray()); - referenceMap.put(getReferenceId(result), result); - return serializer.serialize(getReferenceId(result)); + if (returnReference(result)) { + referenceMap.put(getReferenceId(result), result); + return serializer.serialize(getReferenceId(result)); + } else { + return serializer.serialize(result); + } } catch (Exception e) { throw new RuntimeException(e); } @@ -115,31 +126,78 @@ public class PythonGateway { public byte[] callMethod(byte[] paramsBytes) { try { List params = (List) serializer.deserialize(paramsBytes); - params = processReferenceParameters(params); + params = processParameters(params); LOG.info("callMethod params {}", params); Object obj = params.get(0); String methodName = (String) params.get(1); - Method method = ReflectionUtils.findMethod(obj.getClass(), methodName); + Class clz = obj.getClass(); + Class[] paramsTypes = params.subList(2, params.size()).stream() + .map(Object::getClass).toArray(Class[]::new); + Method method = findMethod(clz, methodName, paramsTypes); Object result = method.invoke(obj, params.subList(2, params.size()).toArray()); - referenceMap.put(getReferenceId(result), result); - return serializer.serialize(getReferenceId(result)); + if (returnReference(result)) { + referenceMap.put(getReferenceId(result), result); + return serializer.serialize(getReferenceId(result)); + } else { + return serializer.serialize(result); + } } catch (Exception e) { throw new RuntimeException(e); } } - private List processReferenceParameters(List params) { - return params.stream().map(this::processReferenceParameter) + private static Method findMethod(Class cls, String methodName, Class[] paramsTypes) { + List methods = ReflectionUtils.findMethods(cls, methodName); + if (methods.size() == 1) { + return methods.get(0); + } + // Convert all params types to primitive types if it's boxed type + Class[] unwrappedTypes = Arrays.stream(paramsTypes) + .map((Function) Primitives::unwrap) + .toArray(Class[]::new); + Optional any = methods.stream() + .filter(m -> Arrays.equals(m.getParameterTypes(), paramsTypes) || + Arrays.equals(m.getParameterTypes(), unwrappedTypes)) + .findAny(); + Preconditions.checkArgument(any.isPresent(), + String.format("Method %s with type %s doesn't exist on class %s", + methodName, Arrays.toString(paramsTypes), cls)); + return any.get(); + } + + private static boolean returnReference(Object value) { + return !(value instanceof Number) && !(value instanceof String) && !(value instanceof byte[]); + } + + public byte[] newInstance(byte[] classNameBytes) { + String className = (String) serializer.deserialize(classNameBytes); + try { + Class clz = Class.forName(className, true, this.getClass().getClassLoader()); + Object instance = clz.newInstance(); + referenceMap.put(getReferenceId(instance), instance); + return serializer.serialize(getReferenceId(instance)); + } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) { + throw new IllegalArgumentException( + String.format("Create instance for class %s failed", className), e); + } + } + + private List processParameters(List params) { + return params.stream().map(this::processParameter) .collect(Collectors.toList()); } - private Object processReferenceParameter(Object o) { + private Object processParameter(Object o) { if (o instanceof String) { Object value = referenceMap.get(o); if (value != null) { return value; } } + // Since python can't represent byte/short, we convert all Byte/Short to Integer + if (o instanceof Byte || o instanceof Short) { + return ((Number) o).intValue(); + } return o; } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/JobSchedulerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/JobSchedulerImpl.java index f1de23c8c..deaaf74b3 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/JobSchedulerImpl.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/JobSchedulerImpl.java @@ -41,15 +41,11 @@ public class JobSchedulerImpl implements JobScheduler { public void schedule(JobGraph jobGraph, Map jobConfig) { this.jobConfig = jobConfig; this.jobGraph = jobGraph; - if (Ray.internal() == null) { - System.setProperty("ray.raylet.config.num_workers_per_process_java", "1"); - Ray.init(); - } ExecutionGraph executionGraph = this.taskAssigner.assign(this.jobGraph); List executionNodes = executionGraph.getExecutionNodeList(); boolean hasPythonNode = executionNodes.stream() - .allMatch(node -> node.getLanguage() == Language.PYTHON); + .anyMatch(node -> node.getLanguage() == Language.PYTHON); RemoteCall.ExecutionGraph executionGraphPb = null; if (hasPythonNode) { executionGraphPb = new GraphPbBuilder().buildExecutionGraphPb(executionGraph); diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/TaskAssignerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/TaskAssignerImpl.java index 171375ed2..04520b441 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/TaskAssignerImpl.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/TaskAssignerImpl.java @@ -2,6 +2,8 @@ package io.ray.streaming.runtime.schedule; import io.ray.api.BaseActor; import io.ray.api.Ray; +import io.ray.api.RayActor; +import io.ray.api.RayPyActor; import io.ray.api.function.PyActorClass; import io.ray.streaming.jobgraph.JobEdge; import io.ray.streaming.jobgraph.JobGraph; @@ -15,8 +17,11 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class TaskAssignerImpl implements TaskAssigner { + private static final Logger LOG = LoggerFactory.getLogger(TaskAssignerImpl.class); /** * Assign an optimized logical plan to execution graph. @@ -61,11 +66,17 @@ public class TaskAssignerImpl implements TaskAssigner { private BaseActor createWorker(JobVertex jobVertex) { switch (jobVertex.getLanguage()) { - case PYTHON: - return Ray.createActor( + case PYTHON: { + RayPyActor worker = Ray.createActor( new PyActorClass("ray.streaming.runtime.worker", "JobWorker")); - case JAVA: - return Ray.createActor(JobWorker::new); + LOG.info("Created python worker {}", worker); + return worker; + } + case JAVA: { + RayActor worker = Ray.createActor(JobWorker::new); + LOG.info("Created java worker {}", worker); + return worker; + } default: throw new UnsupportedOperationException( "Unsupported language " + jobVertex.getLanguage()); diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/CrossLangSerializer.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/CrossLangSerializer.java new file mode 100644 index 000000000..17557b9ac --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/CrossLangSerializer.java @@ -0,0 +1,62 @@ +package io.ray.streaming.runtime.serialization; + +import io.ray.streaming.message.KeyRecord; +import io.ray.streaming.message.Record; +import java.util.Arrays; +import java.util.List; + +/** + * A serializer for cross-lang serialization between java/python. + * TODO implements a more sophisticated serialization framework + */ +public class CrossLangSerializer implements Serializer { + private static final byte RECORD_TYPE_ID = 0; + private static final byte KEY_RECORD_TYPE_ID = 1; + + private MsgPackSerializer msgPackSerializer = new MsgPackSerializer(); + + public byte[] serialize(Object object) { + Record record = (Record) object; + Object value = record.getValue(); + Class clz = record.getClass(); + if (clz == Record.class) { + return msgPackSerializer.serialize(Arrays.asList( + RECORD_TYPE_ID, record.getStream(), value)); + } else if (clz == KeyRecord.class) { + KeyRecord keyRecord = (KeyRecord) record; + Object key = keyRecord.getKey(); + return msgPackSerializer.serialize(Arrays.asList( + KEY_RECORD_TYPE_ID, keyRecord.getStream(), key, value)); + } else { + throw new UnsupportedOperationException( + String.format("Serialize %s is unsupported.", record)); + } + } + + @SuppressWarnings("unchecked") + public Record deserialize(byte[] bytes) { + List list = (List) msgPackSerializer.deserialize(bytes); + Byte typeId = (Byte) list.get(0); + switch (typeId) { + case RECORD_TYPE_ID: { + String stream = (String) list.get(1); + Object value = list.get(2); + Record record = new Record(value); + record.setStream(stream); + return record; + } + case KEY_RECORD_TYPE_ID: { + String stream = (String) list.get(1); + Object key = list.get(2); + Object value = list.get(3); + KeyRecord keyRecord = new KeyRecord(key, value); + keyRecord.setStream(stream); + return keyRecord; + } + default: + throw new UnsupportedOperationException("Unsupported type " + typeId); + + } + } + +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/JavaSerializer.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/JavaSerializer.java new file mode 100644 index 000000000..d7a1a2649 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/JavaSerializer.java @@ -0,0 +1,15 @@ +package io.ray.streaming.runtime.serialization; + +import io.ray.runtime.serializer.FstSerializer; + +public class JavaSerializer implements Serializer { + @Override + public byte[] serialize(Object object) { + return FstSerializer.encode(object); + } + + @Override + public T deserialize(byte[] bytes) { + return FstSerializer.decode(bytes); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/MsgPackSerializer.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/MsgPackSerializer.java similarity index 90% rename from streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/MsgPackSerializer.java rename to streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/MsgPackSerializer.java index 20415a438..2fc9a2c37 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/MsgPackSerializer.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/MsgPackSerializer.java @@ -1,4 +1,4 @@ -package io.ray.streaming.runtime.python; +package io.ray.streaming.runtime.serialization; import com.google.common.io.BaseEncoding; import java.util.ArrayList; @@ -31,6 +31,10 @@ public class MsgPackSerializer { Class clz = obj.getClass(); if (clz == Boolean.class) { packer.packBoolean((Boolean) obj); + } else if (clz == Byte.class) { + packer.packByte((Byte) obj); + } else if (clz == Short.class) { + packer.packShort((Short) obj); } else if (clz == Integer.class) { packer.packInt((Integer) obj); } else if (clz == Long.class) { @@ -84,7 +88,11 @@ public class MsgPackSerializer { return value.asBooleanValue().getBoolean(); case INTEGER: IntegerValue iv = value.asIntegerValue(); - if (iv.isInIntRange()) { + if (iv.isInByteRange()) { + return iv.toByte(); + } else if (iv.isInShortRange()) { + return iv.toShort(); + } else if (iv.isInIntRange()) { return iv.toInt(); } else if (iv.isInLongRange()) { return iv.toLong(); diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/Serializer.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/Serializer.java new file mode 100644 index 000000000..b3a3184d7 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/Serializer.java @@ -0,0 +1,12 @@ +package io.ray.streaming.runtime.serialization; + +public interface Serializer { + byte CROSS_LANG_TYPE_ID = 0; + byte JAVA_TYPE_ID = 1; + byte PYTHON_TYPE_ID = 2; + + byte[] serialize(Object object); + + T deserialize(byte[] bytes); + +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder.java index b152ca3b7..8506560d9 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder.java @@ -20,7 +20,7 @@ import java.util.Map; */ public class ChannelCreationParametersBuilder { - public class Parameter { + public static class Parameter { private ActorId actorId; private FunctionDescriptor asyncFunctionDescriptor; @@ -138,7 +138,7 @@ public class ChannelCreationParametersBuilder { parameter.setAsyncFunctionDescriptor(pyAsyncFunctionDesc); parameter.setSyncFunctionDescriptor(pySyncFunctionDesc); } else { - Preconditions.checkArgument(false, "Invalid actor type"); + throw new IllegalArgumentException("Invalid actor type"); } parameters.add(parameter); } @@ -152,10 +152,10 @@ public class ChannelCreationParametersBuilder { } public String toString() { - String str = ""; + StringBuilder str = new StringBuilder(); for (Parameter param : parameters) { - str += param.toString(); + str.append(param.toString()); } - return str; + return str.toString(); } } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java index 64e17f59c..b69396b43 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java @@ -40,7 +40,7 @@ public class DataReader { } long timerInterval = Long.parseLong( conf.getOrDefault(Config.TIMER_INTERVAL_MS, "-1")); - String channelType = conf.getOrDefault(Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE); + String channelType = conf.get(Config.CHANNEL_TYPE); boolean isMock = false; if (Config.MEMORY_CHANNEL.equals(channelType)) { isMock = true; diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java index 25e02940e..39678aebb 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java @@ -37,7 +37,7 @@ public class DataWriter { Map conf) { Preconditions.checkArgument(!outputChannels.isEmpty()); Preconditions.checkArgument(outputChannels.size() == toActors.size()); - ChannelCreationParametersBuilder initialParameters = + ChannelCreationParametersBuilder initParameters = new ChannelCreationParametersBuilder().buildOutputQueueParameters(outputChannels, toActors); byte[][] outputChannelsBytes = outputChannels.stream() .map(ChannelID::idStrToBytes).toArray(byte[][]::new); @@ -47,13 +47,14 @@ public class DataWriter { for (int i = 0; i < outputChannels.size(); i++) { msgIds[i] = 0; } - String channelType = conf.getOrDefault(Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE); + String channelType = conf.get(Config.CHANNEL_TYPE); boolean isMock = false; - if (Config.MEMORY_CHANNEL.equals(channelType)) { + if (Config.MEMORY_CHANNEL.equalsIgnoreCase(channelType)) { isMock = true; + LOGGER.info("Using memory channel"); } this.nativeWriterPtr = createWriterNative( - initialParameters, + initParameters, outputChannelsBytes, msgIds, channelSize, diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ReflectionUtils.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ReflectionUtils.java index d3f26a06a..5852220af 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ReflectionUtils.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ReflectionUtils.java @@ -19,6 +19,7 @@ public class ReflectionUtils { /** * For covariant return type, return the most specific method. + * * @return all methods named by {@code methodName}, */ public static List findMethods(Class cls, String methodName) { diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java index 75d587c5a..2433d18e9 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java @@ -1,5 +1,6 @@ package io.ray.streaming.runtime.worker; +import io.ray.api.Ray; import io.ray.streaming.runtime.core.graph.ExecutionGraph; import io.ray.streaming.runtime.core.graph.ExecutionNode; import io.ray.streaming.runtime.core.graph.ExecutionNode.NodeType; @@ -14,11 +15,8 @@ import io.ray.streaming.runtime.worker.context.WorkerContext; import io.ray.streaming.runtime.worker.tasks.OneInputStreamTask; import io.ray.streaming.runtime.worker.tasks.SourceStreamTask; import io.ray.streaming.runtime.worker.tasks.StreamTask; -import io.ray.streaming.util.Config; - import java.io.Serializable; import java.util.Map; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -27,6 +25,8 @@ import org.slf4j.LoggerFactory; */ public class JobWorker implements Serializable { private static final Logger LOGGER = LoggerFactory.getLogger(JobWorker.class); + // special flag to indicate this actor not ready + private static final byte[] NOT_READY_FLAG = new byte[4]; static { EnvUtil.loadNativeLibraries(); @@ -53,12 +53,11 @@ public class JobWorker implements Serializable { this.nodeType = executionNode.getNodeType(); this.streamProcessor = ProcessBuilder - .buildProcessor(executionNode.getStreamOperator()); - LOGGER.debug("Initializing StreamWorker, taskId: {}, operator: {}.", taskId, streamProcessor); + .buildProcessor(executionNode.getStreamOperator()); + LOGGER.info("Initializing StreamWorker, pid {}, taskId: {}, operator: {}.", + EnvUtil.getJvmPid(), taskId, streamProcessor); - String channelType = (String) this.config.getOrDefault( - Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE); - if (channelType.equals(Config.NATIVE_CHANNEL)) { + if (!Ray.getRuntimeContext().isSingleProcess()) { transferHandler = new TransferHandler(); } task = createStreamTask(); @@ -124,6 +123,9 @@ public class JobWorker implements Serializable { * and receive result from this actor */ public byte[] onReaderMessageSync(byte[] buffer) { + if (transferHandler == null) { + return NOT_READY_FLAG; + } return transferHandler.onReaderMessageSync(buffer); } @@ -139,6 +141,9 @@ public class JobWorker implements Serializable { * and receive result from this actor */ public byte[] onWriterMessageSync(byte[] buffer) { + if (transferHandler == null) { + return NOT_READY_FLAG; + } return transferHandler.onWriterMessageSync(buffer); } } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java index a3fd7a470..8d642aeef 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java @@ -1,7 +1,9 @@ package io.ray.streaming.runtime.worker.tasks; -import io.ray.runtime.serializer.Serializer; import io.ray.streaming.runtime.core.processor.Processor; +import io.ray.streaming.runtime.serialization.CrossLangSerializer; +import io.ray.streaming.runtime.serialization.JavaSerializer; +import io.ray.streaming.runtime.serialization.Serializer; import io.ray.streaming.runtime.transfer.Message; import io.ray.streaming.runtime.worker.JobWorker; import io.ray.streaming.util.Config; @@ -10,11 +12,15 @@ public abstract class InputStreamTask extends StreamTask { private volatile boolean running = true; private volatile boolean stopped = false; private long readTimeoutMillis; + private final io.ray.streaming.runtime.serialization.Serializer javaSerializer; + private final io.ray.streaming.runtime.serialization.Serializer crossLangSerializer; public InputStreamTask(int taskId, Processor processor, JobWorker streamWorker) { super(taskId, processor, streamWorker); readTimeoutMillis = Long.parseLong((String) streamWorker.getConfig() .getOrDefault(Config.READ_TIMEOUT_MS, Config.DEFAULT_READ_TIMEOUT_MS)); + javaSerializer = new JavaSerializer(); + crossLangSerializer = new CrossLangSerializer(); } @Override @@ -26,9 +32,15 @@ public abstract class InputStreamTask extends StreamTask { while (running) { Message item = reader.read(readTimeoutMillis); if (item != null) { - byte[] bytes = new byte[item.body().remaining()]; + byte[] bytes = new byte[item.body().remaining() - 1]; + byte typeId = item.body().get(); item.body().get(bytes); - Object obj = Serializer.decode(bytes, Object.class); + Object obj; + if (typeId == Serializer.JAVA_TYPE_ID) { + obj = javaSerializer.deserialize(bytes); + } else { + obj = crossLangSerializer.deserialize(bytes); + } processor.process(obj); } } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java index d16cc029d..ca2e6aa99 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java @@ -26,7 +26,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; public abstract class StreamTask implements Runnable { - private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class); protected int taskId; @@ -53,8 +52,8 @@ public abstract class StreamTask implements Runnable { String queueSize = worker.getConfig() .getOrDefault(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT); queueConf.put(Config.CHANNEL_SIZE, queueSize); - String channelType = worker.getConfig() - .getOrDefault(Config.CHANNEL_TYPE, Config.MEMORY_CHANNEL); + String channelType = Ray.getRuntimeContext().isSingleProcess() ? + Config.MEMORY_CHANNEL : Config.NATIVE_CHANNEL; queueConf.put(Config.CHANNEL_TYPE, channelType); ExecutionGraph executionGraph = worker.getExecutionGraph(); @@ -82,7 +81,7 @@ public abstract class StreamTask implements Runnable { LOG.info("Create DataWriter succeed."); writers.put(edge, writer); Partition partition = edge.getPartition(); - collectors.add(new OutputCollector(channelIDs, writer, partition)); + collectors.add(new OutputCollector(writer, channelIDs, outputActors.values(), partition)); } } @@ -106,8 +105,8 @@ public abstract class StreamTask implements Runnable { reader = new DataReader(channelIDs, inputActors, queueConf); } - RuntimeContext runtimeContext = new RayRuntimeContext(worker.getExecutionTask(), - worker.getConfig(), executionNode.getParallelism()); + RuntimeContext runtimeContext = new RayRuntimeContext( + worker.getExecutionTask(), worker.getConfig(), executionNode.getParallelism()); processor.open(collectors, runtimeContext); diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/BaseUnitTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/BaseUnitTest.java index e757f14e1..593851a86 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/BaseUnitTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/BaseUnitTest.java @@ -24,11 +24,13 @@ public abstract class BaseUnitTest { @BeforeMethod public void testBegin(Method method) { - LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: " + method.getName() + " began >>>>>>>>>>>>>>>>>>>>"); + LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: {}.{} began >>>>>>>>>>>>>>>>>>>>", + method.getDeclaringClass(), method.getName()); } @AfterMethod public void testEnd(Method method) { - LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: " + method.getName() + " end >>>>>>>>>>>>>>>>>>"); + LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: {}.{} end >>>>>>>>>>>>>>>>>>>>", + method.getDeclaringClass(), method.getName()); } } diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java index 920bc1f74..882fc5fb4 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java @@ -80,7 +80,7 @@ public class ExecutionGraphTest extends BaseUnitTest { public static JobGraph buildJobGraph() { StreamingContext streamingContext = StreamingContext.buildContext(); - DataStream dataStream = DataStreamSource.buildSource(streamingContext, + DataStream dataStream = DataStreamSource.fromCollection(streamingContext, Lists.newArrayList("a", "b", "c")); StreamSink streamSink = dataStream.sink(x -> LOG.info(x)); diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/HybridStreamTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/HybridStreamTest.java new file mode 100644 index 000000000..025f67e21 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/HybridStreamTest.java @@ -0,0 +1,56 @@ +package io.ray.streaming.runtime.demo; + +import io.ray.api.Ray; +import io.ray.streaming.api.context.StreamingContext; +import io.ray.streaming.api.function.impl.FilterFunction; +import io.ray.streaming.api.function.impl.MapFunction; +import io.ray.streaming.api.stream.DataStreamSource; +import io.ray.streaming.runtime.BaseUnitTest; +import java.io.Serializable; +import java.util.Arrays; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.annotations.Test; + +public class HybridStreamTest extends BaseUnitTest implements Serializable { + private static final Logger LOG = LoggerFactory.getLogger(HybridStreamTest.class); + + public static class Mapper1 implements MapFunction { + + @Override + public Object map(Object value) { + LOG.info("HybridStreamTest Mapper1 {}", value); + return value.toString(); + } + } + + public static class Filter1 implements FilterFunction { + + @Override + public boolean filter(Object value) throws Exception { + LOG.info("HybridStreamTest Filter1 {}", value); + return !value.toString().contains("b"); + } + } + + @Test + public void testHybridDataStream() throws InterruptedException { + Ray.shutdown(); + StreamingContext context = StreamingContext.buildContext(); + DataStreamSource streamSource = + DataStreamSource.fromCollection(context, Arrays.asList("a", "b", "c")); + streamSource + .map(x -> x + x) + .asPythonStream() + .map("ray.streaming.tests.test_hybrid_stream", "map_func1") + .filter("ray.streaming.tests.test_hybrid_stream", "filter_func1") + .asJavaStream() + .sink(x -> System.out.println("HybridStreamTest: " + x)); + context.execute("HybridStreamTestJob"); + TimeUnit.SECONDS.sleep(3); + context.stop(); + LOG.info("HybridStreamTest succeed"); + } + +} diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/WordCountTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/WordCountTest.java index 389c1bc1a..5669ad12f 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/WordCountTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/WordCountTest.java @@ -1,6 +1,7 @@ package io.ray.streaming.runtime.demo; import com.google.common.collect.ImmutableMap; +import io.ray.api.Ray; import io.ray.streaming.api.context.StreamingContext; import io.ray.streaming.api.function.impl.FlatMapFunction; import io.ray.streaming.api.function.impl.ReduceFunction; @@ -29,6 +30,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable { @Test public void testWordCount() { + Ray.shutdown(); StreamingContext streamingContext = StreamingContext.buildContext(); Map config = new HashMap<>(); config.put(Config.STREAMING_BATCH_MAX_COUNT, "1"); @@ -36,7 +38,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable { streamingContext.withConfig(config); List text = new ArrayList<>(); text.add("hello world eagle eagle eagle"); - DataStreamSource streamSource = DataStreamSource.buildSource(streamingContext, text); + DataStreamSource streamSource = DataStreamSource.fromCollection(streamingContext, text); streamSource .flatMap((FlatMapFunction) (value, collector) -> { String[] records = value.split(" "); @@ -62,6 +64,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable { } } Assert.assertEquals(wordCount, ImmutableMap.of("eagle", 3, "hello", 1)); + streamingContext.stop(); } private static class WordAndCount implements Serializable { diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/PythonGatewayTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/PythonGatewayTest.java index 51440dba6..5922cc578 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/PythonGatewayTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/PythonGatewayTest.java @@ -3,6 +3,7 @@ package io.ray.streaming.runtime.python; import io.ray.streaming.api.stream.StreamSink; import io.ray.streaming.jobgraph.JobGraph; import io.ray.streaming.jobgraph.JobGraphBuilder; +import io.ray.streaming.runtime.serialization.MsgPackSerializer; import java.util.Arrays; import java.util.HashMap; import java.util.List; diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/schedule/TaskAssignerImplTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/schedule/TaskAssignerImplTest.java index 7c2e7e7ff..2e8978c7a 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/schedule/TaskAssignerImplTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/schedule/TaskAssignerImplTest.java @@ -57,7 +57,7 @@ public class TaskAssignerImplTest extends BaseUnitTest { public JobGraph buildDataSyncPlan() { StreamingContext streamingContext = StreamingContext.buildContext(); - DataStream dataStream = DataStreamSource.buildSource(streamingContext, + DataStream dataStream = DataStreamSource.fromCollection(streamingContext, Lists.newArrayList("a", "b", "c")); DataStreamSink streamSink = dataStream.sink(LOGGER::info); JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink)); diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/CrossLangSerializerTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/CrossLangSerializerTest.java new file mode 100644 index 000000000..be92792b6 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/CrossLangSerializerTest.java @@ -0,0 +1,26 @@ +package io.ray.streaming.runtime.serialization; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +import org.apache.commons.lang3.builder.EqualsBuilder; +import io.ray.streaming.message.KeyRecord; +import io.ray.streaming.message.Record; +import org.testng.annotations.Test; + +public class CrossLangSerializerTest { + + @Test + @SuppressWarnings("unchecked") + public void testSerialize() { + CrossLangSerializer serializer = new CrossLangSerializer(); + Record record = new Record("value"); + record.setStream("stream1"); + assertTrue(EqualsBuilder.reflectionEquals(record, + serializer.deserialize(serializer.serialize(record)))); + KeyRecord keyRecord = new KeyRecord("key", "value"); + keyRecord.setStream("stream2"); + assertEquals(keyRecord, + serializer.deserialize(serializer.serialize(keyRecord))); + } +} \ No newline at end of file diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/MsgPackSerializerTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/MsgPackSerializerTest.java similarity index 59% rename from streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/MsgPackSerializerTest.java rename to streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/MsgPackSerializerTest.java index b2213538b..44568df8d 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/MsgPackSerializerTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/MsgPackSerializerTest.java @@ -1,4 +1,7 @@ -package io.ray.streaming.runtime.python; +package io.ray.streaming.runtime.serialization; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; import java.util.ArrayList; import java.util.Arrays; @@ -6,25 +9,37 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import org.testng.annotations.Test; -import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertTrue; @SuppressWarnings("unchecked") public class MsgPackSerializerTest { + @Test + public void testSerializeByte() { + MsgPackSerializer serializer = new MsgPackSerializer(); + + assertEquals(serializer.deserialize( + serializer.serialize((byte)1)), (byte)1); + } + @Test public void testSerialize() { MsgPackSerializer serializer = new MsgPackSerializer(); + assertEquals(serializer.deserialize + (serializer.serialize(Short.MAX_VALUE)), Short.MAX_VALUE); + assertEquals(serializer.deserialize( + serializer.serialize(Integer.MAX_VALUE)), Integer.MAX_VALUE); + assertEquals(serializer.deserialize( + serializer.serialize(Long.MAX_VALUE)), Long.MAX_VALUE); + Map map = new HashMap(); List list = new ArrayList<>(); list.add(null); list.add(true); - list.add(1); list.add(1.0d); list.add("str"); map.put("k1", "value1"); - map.put("k2", 2); + map.put("k2", new HashMap<>()); map.put("k3", list); byte[] bytes = serializer.serialize(map); Object o = serializer.deserialize(bytes); diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java index cfa34dd04..c48293cea 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java @@ -5,6 +5,7 @@ import io.ray.api.Ray; import io.ray.api.RayActor; import io.ray.api.options.ActorCreationOptions; import io.ray.api.options.ActorCreationOptions.Builder; +import io.ray.runtime.config.RayConfig; import io.ray.streaming.api.context.StreamingContext; import io.ray.streaming.api.function.impl.FlatMapFunction; import io.ray.streaming.api.function.impl.ReduceFunction; @@ -67,7 +68,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable { System.setProperty("ray.raylet.config.num_workers_per_process_java", "1"); System.setProperty("ray.run-mode", "CLUSTER"); System.setProperty("ray.redirect-output", "true"); - // ray init + RayConfig.reset(); Ray.init(); } @@ -142,6 +143,14 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable { @Test(timeOut = 60000) public void testWordCount() { + Ray.shutdown(); + System.setProperty("ray.resources", "CPU:4,RES-A:4"); + System.setProperty("ray.raylet.config.num_workers_per_process_java", "1"); + + System.setProperty("ray.run-mode", "CLUSTER"); + System.setProperty("ray.redirect-output", "true"); + // ray init + Ray.init(); LOGGER.info("testWordCount"); LOGGER.info("StreamingQueueTest.testWordCount run-mode: {}", System.getProperty("ray.run-mode")); @@ -157,7 +166,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable { streamingContext.withConfig(config); List text = new ArrayList<>(); text.add("hello world eagle eagle eagle"); - DataStreamSource streamSource = DataStreamSource.buildSource(streamingContext, text); + DataStreamSource streamSource = DataStreamSource.fromCollection(streamingContext, text); streamSource .flatMap((FlatMapFunction) (value, collector) -> { String[] records = value.split(" "); @@ -176,7 +185,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable { serializeResultToFile(resultFile, wordCount); }); - streamingContext.execute("testWordCount"); + streamingContext.execute("testSQWordCount"); Map checkWordCount = (Map) deserializeResultFromFile(resultFile); diff --git a/streaming/java/test.sh b/streaming/java/test.sh index e3225452c..ecf9770a8 100755 --- a/streaming/java/test.sh +++ b/streaming/java/test.sh @@ -23,8 +23,11 @@ bazel test //streaming/java:all --test_tag_filters="checkstyle" --build_tests_on echo "Running streaming tests." java -cp "$ROOT_DIR"/../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar\ - org.testng.TestNG -d /tmp/ray_streaming_java_test_output "$ROOT_DIR"/testng.xml + org.testng.TestNG -d /tmp/ray_streaming_java_test_output "$ROOT_DIR"/testng.xml || exit_code=$? +if [ -z ${exit_code+x} ]; then + exit_code=0 +fi echo "Streaming TestNG results" if [ -f "/tmp/ray_streaming_java_test_output/testng-results.xml" ] ; then cat /tmp/ray_streaming_java_test_output/testng-results.xml diff --git a/streaming/python/collector.py b/streaming/python/collector.py index cc803eaf4..12b6c096b 100644 --- a/streaming/python/collector.py +++ b/streaming/python/collector.py @@ -1,10 +1,13 @@ import logging -import pickle import typing from abc import ABC, abstractmethod +from ray import Language +from ray.actor import ActorHandle +from ray.streaming import function from ray.streaming import message from ray.streaming import partition +from ray.streaming.runtime import serialization from ray.streaming.runtime.transfer import ChannelID, DataWriter logger = logging.getLogger(__name__) @@ -31,19 +34,46 @@ class CollectionCollector(Collector): class OutputCollector(Collector): - def __init__(self, channel_ids: typing.List[str], writer: DataWriter, + def __init__(self, writer: DataWriter, channel_ids: typing.List[str], + target_actors: typing.List[ActorHandle], partition_func: partition.Partition): - self._channel_ids = [ChannelID(id_str) for id_str in channel_ids] self._writer = writer + self._channel_ids = [ChannelID(id_str) for id_str in channel_ids] + self._target_languages = [] + for actor in target_actors: + if actor._ray_actor_language == Language.PYTHON: + self._target_languages.append(function.Language.PYTHON) + elif actor._ray_actor_language == Language.JAVA: + self._target_languages.append(function.Language.JAVA) + else: + raise Exception("Unsupported language {}" + .format(actor._ray_actor_language)) self._partition_func = partition_func + self.python_serializer = serialization.PythonSerializer() + self.cross_lang_serializer = serialization.CrossLangSerializer() logger.info( "Create OutputCollector, channel_ids {}, partition_func {}".format( channel_ids, partition_func)) def collect(self, record): - partitions = self._partition_func.partition(record, - len(self._channel_ids)) - serialized_message = pickle.dumps(record) + partitions = self._partition_func \ + .partition(record, len(self._channel_ids)) + python_buffer = None + cross_lang_buffer = None for partition_index in partitions: - self._writer.write(self._channel_ids[partition_index], - serialized_message) + if self._target_languages[partition_index] == \ + function.Language.PYTHON: + # avoid repeated serialization + if python_buffer is None: + python_buffer = self.python_serializer.serialize(record) + self._writer.write( + self._channel_ids[partition_index], + serialization._PYTHON_TYPE_ID + python_buffer) + else: + # avoid repeated serialization + if cross_lang_buffer is None: + cross_lang_buffer = self.cross_lang_serializer.serialize( + record) + self._writer.write( + self._channel_ids[partition_index], + serialization._CROSS_LANG_TYPE_ID + cross_lang_buffer) diff --git a/streaming/python/datastream.py b/streaming/python/datastream.py index 39a067a6a..26297da11 100644 --- a/streaming/python/datastream.py +++ b/streaming/python/datastream.py @@ -1,4 +1,4 @@ -from abc import ABC +from abc import ABC, abstractmethod from ray.streaming import function from ray.streaming import partition @@ -19,7 +19,6 @@ class Stream(ABC): self.streaming_context = input_stream.streaming_context else: self.streaming_context = streaming_context - self.parallelism = 1 def get_streaming_context(self): return self.streaming_context @@ -29,7 +28,8 @@ class Stream(ABC): Returns: the parallelism of this transformation """ - return self.parallelism + return self._gateway_client(). \ + call_method(self._j_stream, "getParallelism") def set_parallelism(self, parallelism: int): """Sets the parallelism of this transformation @@ -40,7 +40,6 @@ class Stream(ABC): Returns: self """ - self.parallelism = parallelism self._gateway_client(). \ call_method(self._j_stream, "setParallelism", parallelism) return self @@ -60,6 +59,10 @@ class Stream(ABC): return self._gateway_client(). \ call_method(self._j_stream, "getId") + @abstractmethod + def get_language(self): + pass + def _gateway_client(self): return self.get_streaming_context()._gateway_client @@ -75,6 +78,9 @@ class DataStream(Stream): super().__init__( input_stream, j_stream, streaming_context=streaming_context) + def get_language(self): + return function.Language.PYTHON + def map(self, func): """ Applies a Map transformation on a :class:`DataStream`. @@ -158,6 +164,7 @@ class DataStream(Stream): Returns: A KeyDataStream """ + self._check_partition_call() if not isinstance(func, function.KeyFunction): func = function.SimpleKeyFunction(func) j_func = self._gateway_client().create_py_func( @@ -175,6 +182,7 @@ class DataStream(Stream): Returns: The DataStream with broadcast partitioning set. """ + self._check_partition_call() self._gateway_client().call_method(self._j_stream, "broadcast") return self @@ -191,6 +199,7 @@ class DataStream(Stream): Returns: The DataStream with specified partitioning set. """ + self._check_partition_call() if not isinstance(partition_func, partition.Partition): partition_func = partition.SimplePartition(partition_func) j_partition = self._gateway_client().create_py_func( @@ -199,6 +208,16 @@ class DataStream(Stream): call_method(self._j_stream, "partitionBy", j_partition) return self + def _check_partition_call(self): + """ + If parent stream is a java stream, we can't call partition related + methods in the python stream + """ + if self.input_stream is not None and \ + self.input_stream.get_language() == function.Language.JAVA: + raise Exception("Partition related methods can't be called on a " + "python stream if parent stream is a java stream.") + def sink(self, func): """ Create a StreamSink with the given sink. @@ -217,8 +236,97 @@ class DataStream(Stream): call_method(self._j_stream, "sink", j_func) return StreamSink(self, j_stream, func) + def as_java_stream(self): + """ + Convert this stream as a java JavaDataStream. + The converted stream and this stream are the same logical stream, + which has same stream id. Changes in converted stream will be reflected + in this stream and vice versa. + """ + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "asJavaStream") + return JavaDataStream(self, j_stream) -class KeyDataStream(Stream): + +class JavaDataStream(Stream): + """ + Represents a stream of data which applies a transformation executed by + java. It's also a wrapper of java + `org.ray.streaming.api.stream.DataStream` + """ + + def __init__(self, input_stream, j_stream, streaming_context=None): + super().__init__( + input_stream, j_stream, streaming_context=streaming_context) + + def get_language(self): + return function.Language.JAVA + + def map(self, java_func_class): + """See org.ray.streaming.api.stream.DataStream.map""" + return JavaDataStream(self, self._unary_call("map", java_func_class)) + + def flat_map(self, java_func_class): + """See org.ray.streaming.api.stream.DataStream.flatMap""" + return JavaDataStream(self, self._unary_call("flatMap", + java_func_class)) + + def filter(self, java_func_class): + """See org.ray.streaming.api.stream.DataStream.filter""" + return JavaDataStream(self, self._unary_call("filter", + java_func_class)) + + def key_by(self, java_func_class): + """See org.ray.streaming.api.stream.DataStream.keyBy""" + self._check_partition_call() + return JavaKeyDataStream(self, + self._unary_call("keyBy", java_func_class)) + + def broadcast(self, java_func_class): + """See org.ray.streaming.api.stream.DataStream.broadcast""" + self._check_partition_call() + return JavaDataStream(self, + self._unary_call("broadcast", java_func_class)) + + def partition_by(self, java_func_class): + """See org.ray.streaming.api.stream.DataStream.partitionBy""" + self._check_partition_call() + return JavaDataStream(self, + self._unary_call("partitionBy", java_func_class)) + + def sink(self, java_func_class): + """See org.ray.streaming.api.stream.DataStream.sink""" + return JavaStreamSink(self, self._unary_call("sink", java_func_class)) + + def as_python_stream(self): + """ + Convert this stream as a python DataStream. + The converted stream and this stream are the same logical stream, + which has same stream id. Changes in converted stream will be reflected + in this stream and vice versa. + """ + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "asPythonStream") + return DataStream(self, j_stream) + + def _check_partition_call(self): + """ + If parent stream is a python stream, we can't call partition related + methods in the java stream + """ + if self.input_stream is not None and \ + self.input_stream.get_language() == function.Language.PYTHON: + raise Exception("Partition related methods can't be called on a" + "java stream if parent stream is a python stream.") + + def _unary_call(self, func_name, java_func_class): + j_func = self._gateway_client().new_instance(java_func_class) + j_stream = self._gateway_client(). \ + call_method(self._j_stream, func_name, j_func) + return j_stream + + +class KeyDataStream(DataStream): """Represents a DataStream returned by a key-by operation. Wrapper of java io.ray.streaming.python.stream.PythonKeyDataStream """ @@ -251,6 +359,43 @@ class KeyDataStream(Stream): call_method(self._j_stream, "reduce", j_func) return DataStream(self, j_stream) + def as_java_stream(self): + """ + Convert this stream as a java KeyDataStream. + The converted stream and this stream are the same logical stream, + which has same stream id. Changes in converted stream will be reflected + in this stream and vice versa. + """ + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "asJavaStream") + return JavaKeyDataStream(self, j_stream) + + +class JavaKeyDataStream(JavaDataStream): + """ + Represents a DataStream returned by a key-by operation in java. + Wrapper of org.ray.streaming.api.stream.KeyDataStream + """ + + def __init__(self, input_stream, j_stream): + super().__init__(input_stream, j_stream) + + def reduce(self, java_func_class): + """See org.ray.streaming.api.stream.KeyDataStream.reduce""" + return JavaDataStream(self, + super()._unary_call("reduce", java_func_class)) + + def as_python_stream(self): + """ + Convert this stream as a python KeyDataStream. + The converted stream and this stream are the same logical stream, + which has same stream id. Changes in converted stream will be reflected + in this stream and vice versa. + """ + j_stream = self._gateway_client(). \ + call_method(self._j_stream, "asPythonStream") + return KeyDataStream(self, j_stream) + class StreamSource(DataStream): """Represents a source of the DataStream. @@ -261,9 +406,12 @@ class StreamSource(DataStream): super().__init__(None, j_stream, streaming_context=streaming_context) self.source_func = source_func + def get_language(self): + return function.Language.PYTHON + @staticmethod def build_source(streaming_context, func): - """Build a StreamSource source from a collection. + """Build a StreamSource source from a source function. Args: streaming_context: Stream context func: A instance of `SourceFunction` @@ -275,6 +423,34 @@ class StreamSource(DataStream): return StreamSource(j_stream, streaming_context, func) +class JavaStreamSource(JavaDataStream): + """Represents a source of the java DataStream. + Wrapper of java org.ray.streaming.api.stream.DataStreamSource + """ + + def __init__(self, j_stream, streaming_context): + super().__init__(None, j_stream, streaming_context=streaming_context) + + def get_language(self): + return function.Language.JAVA + + @staticmethod + def build_source(streaming_context, java_source_func_class): + """Build a java StreamSource source from a java source function. + Args: + streaming_context: Stream context + java_source_func_class: qualified class name of java SourceFunction + Returns: + A java StreamSource + """ + j_func = streaming_context._gateway_client() \ + .new_instance(java_source_func_class) + j_stream = streaming_context._gateway_client() \ + .call_function("org.ray.streaming.api.stream.DataStreamSource" + "fromSource", streaming_context._j_ctx, j_func) + return JavaStreamSource(j_stream, streaming_context) + + class StreamSink(Stream): """Represents a sink of the DataStream. Wrapper of java io.ray.streaming.python.stream.PythonStreamSink @@ -282,3 +458,18 @@ class StreamSink(Stream): def __init__(self, input_stream, j_stream, func): super().__init__(input_stream, j_stream) + + def get_language(self): + return function.Language.PYTHON + + +class JavaStreamSink(Stream): + """Represents a sink of the java DataStream. + Wrapper of java org.ray.streaming.api.stream.StreamSink + """ + + def __init__(self, input_stream, j_stream): + super().__init__(input_stream, j_stream) + + def get_language(self): + return function.Language.JAVA diff --git a/streaming/python/function.py b/streaming/python/function.py index 9a9a22a19..8d38ae6bc 100644 --- a/streaming/python/function.py +++ b/streaming/python/function.py @@ -1,13 +1,19 @@ +import enum import importlib import inspect import sys -from abc import ABC, abstractmethod import typing +from abc import ABC, abstractmethod from ray import cloudpickle from ray.streaming.runtime import gateway_client +class Language(enum.Enum): + JAVA = 0 + PYTHON = 1 + + class Function(ABC): """The base interface for all user-defined functions.""" @@ -60,6 +66,7 @@ class MapFunction(Function): for each input element. """ + @abstractmethod def map(self, value): pass @@ -70,6 +77,7 @@ class FlatMapFunction(Function): transform them into zero, one, or more elements. """ + @abstractmethod def flat_map(self, value, collector): """Takes an element from the input data set and transforms it into zero, one, or more elements. @@ -87,6 +95,7 @@ class FilterFunction(Function): The predicate decides whether to keep the element, or to discard it. """ + @abstractmethod def filter(self, value): """The filter function that evaluates the predicate. @@ -106,6 +115,7 @@ class KeyFunction(Function): deterministic key for that object. """ + @abstractmethod def key_by(self, value): """User-defined function that deterministically extracts the key from an object. @@ -126,6 +136,7 @@ class ReduceFunction(Function): them into one. """ + @abstractmethod def reduce(self, old_value, new_value): """ The core method of ReduceFunction, combining two values into one value @@ -145,6 +156,7 @@ class ReduceFunction(Function): class SinkFunction(Function): """Interface for implementing user defined sink functionality.""" + @abstractmethod def sink(self, value): """Writes the given value to the sink. This function is called for every record.""" @@ -283,7 +295,8 @@ def load_function(descriptor_func_bytes: bytes): Returns: a streaming function """ - function_bytes, module_name, class_name, function_name, function_interface\ + assert len(descriptor_func_bytes) > 0 + function_bytes, module_name, function_name, function_interface\ = gateway_client.deserialize(descriptor_func_bytes) if function_bytes: return deserialize(function_bytes) @@ -292,16 +305,18 @@ def load_function(descriptor_func_bytes: bytes): assert function_interface function_interface = getattr(sys.modules[__name__], function_interface) mod = importlib.import_module(module_name) - if class_name: - assert function_name is None - cls = getattr(mod, class_name) - assert issubclass(cls, function_interface) - return cls() - else: - assert function_name - func = getattr(mod, function_name) + assert function_name + func = getattr(mod, function_name) + # If func is a python function, user function is a simple python + # function, which will be wrapped as a SimpleXXXFunction. + # If func is a python class, user function is a sub class + # of XXXFunction. + if inspect.isfunction(func): simple_func_class = _get_simple_function_class(function_interface) return simple_func_class(func) + else: + assert issubclass(func, function_interface) + return func() def _get_simple_function_class(function_interface): diff --git a/streaming/python/message.py b/streaming/python/message.py index fab29d4bf..94d928e1d 100644 --- a/streaming/python/message.py +++ b/streaming/python/message.py @@ -8,6 +8,14 @@ class Record: def __repr__(self): return "Record(%s)".format(self.value) + def __eq__(self, other): + if type(self) is type(other): + return (self.stream, self.value) == (other.stream, other.value) + return False + + def __hash__(self): + return hash((self.stream, self.value)) + class KeyRecord(Record): """Data record in a keyed data stream""" @@ -15,3 +23,12 @@ class KeyRecord(Record): def __init__(self, key, value): super().__init__(value) self.key = key + + def __eq__(self, other): + if type(self) is type(other): + return (self.stream, self.key, self.value) ==\ + (other.stream, other.key, other.value) + return False + + def __hash__(self): + return hash((self.stream, self.key, self.value)) diff --git a/streaming/python/partition.py b/streaming/python/partition.py index 722fb7933..198fbe3d7 100644 --- a/streaming/python/partition.py +++ b/streaming/python/partition.py @@ -1,4 +1,5 @@ import importlib +import inspect from abc import ABC, abstractmethod from ray import cloudpickle @@ -96,22 +97,22 @@ def load_partition(descriptor_partition_bytes: bytes): Returns: partition function """ - partition_bytes, module_name, class_name, function_name =\ + assert len(descriptor_partition_bytes) > 0 + partition_bytes, module_name, function_name =\ gateway_client.deserialize(descriptor_partition_bytes) if partition_bytes: return deserialize(partition_bytes) else: assert module_name mod = importlib.import_module(module_name) - # If class_name is not None, user partition is a sub class - # of Partition. - # If function_name is not None, user partition is a simple python + assert function_name + func = getattr(mod, function_name) + # If func is a python function, user partition is a simple python # function, which will be wrapped as a SimplePartition. - if class_name: - assert function_name is None - cls = getattr(mod, class_name) - return cls() - else: - assert function_name - func = getattr(mod, function_name) + # If func is a python class, user partition is a sub class + # of Partition. + if inspect.isfunction(func): return SimplePartition(func) + else: + assert issubclass(func, Partition) + return func() diff --git a/streaming/python/runtime/gateway_client.py b/streaming/python/runtime/gateway_client.py index 5477d9230..8fa4fac61 100644 --- a/streaming/python/runtime/gateway_client.py +++ b/streaming/python/runtime/gateway_client.py @@ -55,6 +55,11 @@ class GatewayClient: call = self._python_gateway_actor.callMethod.remote(java_params) return deserialize(ray.get(call)) + def new_instance(self, java_class_name): + call = self._python_gateway_actor.newInstance.remote( + serialize(java_class_name)) + return deserialize(ray.get(call)) + def serialize(obj) -> bytes: """Serialize a python object which can be deserialized by `PythonGateway` diff --git a/streaming/python/runtime/graph.py b/streaming/python/runtime/graph.py index 78396dce5..645827601 100644 --- a/streaming/python/runtime/graph.py +++ b/streaming/python/runtime/graph.py @@ -53,7 +53,9 @@ class ExecutionEdge: self.src_node_id = edge_pb.src_node_id self.target_node_id = edge_pb.target_node_id partition_bytes = edge_pb.partition - if language == Language.PYTHON: + # Sink node doesn't have partition function, + # so we only deserialize partition_bytes when it's not None or empty + if language == Language.PYTHON and partition_bytes: self.partition = partition.load_partition(partition_bytes) diff --git a/streaming/python/runtime/serialization.py b/streaming/python/runtime/serialization.py new file mode 100644 index 000000000..a01bf4e2c --- /dev/null +++ b/streaming/python/runtime/serialization.py @@ -0,0 +1,57 @@ +from abc import ABC, abstractmethod +import pickle +import msgpack +from ray.streaming import message + +_RECORD_TYPE_ID = 0 +_KEY_RECORD_TYPE_ID = 1 +_CROSS_LANG_TYPE_ID = b"0" +_JAVA_TYPE_ID = b"1" +_PYTHON_TYPE_ID = b"2" + + +class Serializer(ABC): + @abstractmethod + def serialize(self, obj): + pass + + @abstractmethod + def deserialize(self, serialized_bytes): + pass + + +class PythonSerializer(Serializer): + def serialize(self, obj): + return pickle.dumps(obj) + + def deserialize(self, serialized_bytes): + return pickle.loads(serialized_bytes) + + +class CrossLangSerializer(Serializer): + """Serialize stream element between java/python""" + + def serialize(self, obj): + if type(obj) is message.Record: + fields = [_RECORD_TYPE_ID, obj.stream, obj.value] + elif type(obj) is message.KeyRecord: + fields = [_KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value] + else: + raise Exception("Unsupported value {}".format(obj)) + return msgpack.packb(fields, use_bin_type=True) + + def deserialize(self, data): + fields = msgpack.unpackb(data, raw=False) + if fields[0] == _RECORD_TYPE_ID: + stream, value = fields[1:] + record = message.Record(value) + record.stream = stream + return record + elif fields[0] == _KEY_RECORD_TYPE_ID: + stream, key, value = fields[1:] + key_record = message.KeyRecord(key, value) + key_record.stream = stream + return key_record + else: + raise Exception("Unsupported type id {}, type {}".format( + fields[0], type(fields[0]))) diff --git a/streaming/python/runtime/task.py b/streaming/python/runtime/task.py index ee0aeb561..c207c4727 100644 --- a/streaming/python/runtime/task.py +++ b/streaming/python/runtime/task.py @@ -1,11 +1,13 @@ import logging -import pickle import threading from abc import ABC, abstractmethod from ray.streaming.collector import OutputCollector from ray.streaming.config import Config from ray.streaming.context import RuntimeContextImpl +from ray.streaming.runtime import serialization +from ray.streaming.runtime.serialization import \ + PythonSerializer, CrossLangSerializer from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader logger = logging.getLogger(__name__) @@ -38,36 +40,40 @@ class StreamTask(ABC): # writers collectors = [] for edge in execution_node.output_edges: - output_actor_ids = {} + output_actors_map = {} task_id2_worker = execution_graph.get_task_id2_worker_by_node_id( edge.target_node_id) for target_task_id, target_actor in task_id2_worker.items(): channel_name = ChannelID.gen_id(self.task_id, target_task_id, execution_graph.build_time()) - output_actor_ids[channel_name] = target_actor - if len(output_actor_ids) > 0: - channel_ids = list(output_actor_ids.keys()) - to_actor_ids = list(output_actor_ids.values()) - writer = DataWriter(channel_ids, to_actor_ids, channel_conf) - logger.info("Create DataWriter succeed.") + output_actors_map[channel_name] = target_actor + if len(output_actors_map) > 0: + channel_ids = list(output_actors_map.keys()) + target_actors = list(output_actors_map.values()) + logger.info( + "Create DataWriter channel_ids {}, target_actors {}." + .format(channel_ids, target_actors)) + writer = DataWriter(channel_ids, target_actors, channel_conf) self.writers[edge] = writer collectors.append( - OutputCollector(channel_ids, writer, edge.partition)) + OutputCollector(writer, channel_ids, target_actors, + edge.partition)) # readers - input_actor_ids = {} + input_actor_map = {} for edge in execution_node.input_edges: task_id2_worker = execution_graph.get_task_id2_worker_by_node_id( edge.src_node_id) for src_task_id, src_actor in task_id2_worker.items(): channel_name = ChannelID.gen_id(src_task_id, self.task_id, execution_graph.build_time()) - input_actor_ids[channel_name] = src_actor - if len(input_actor_ids) > 0: - channel_ids = list(input_actor_ids.keys()) - from_actor_ids = list(input_actor_ids.values()) - logger.info("Create DataReader, channels {}.".format(channel_ids)) - self.reader = DataReader(channel_ids, from_actor_ids, channel_conf) + input_actor_map[channel_name] = src_actor + if len(input_actor_map) > 0: + channel_ids = list(input_actor_map.keys()) + from_actors = list(input_actor_map.values()) + logger.info("Create DataReader, channels {}, input_actors {}." + .format(channel_ids, from_actors)) + self.reader = DataReader(channel_ids, from_actors, channel_conf) def exit_handler(): # Make DataReader stop read data when MockQueue destructor @@ -111,6 +117,8 @@ class InputStreamTask(StreamTask): self.read_timeout_millis = \ int(worker.config.get(Config.READ_TIMEOUT_MS, Config.DEFAULT_READ_TIMEOUT_MS)) + self.python_serializer = PythonSerializer() + self.cross_lang_serializer = CrossLangSerializer() def init(self): pass @@ -120,7 +128,11 @@ class InputStreamTask(StreamTask): item = self.reader.read(self.read_timeout_millis) if item is not None: msg_data = item.body() - msg = pickle.loads(msg_data) + type_id = msg_data[:1] + if (type_id == serialization._PYTHON_TYPE_ID): + msg = self.python_serializer.deserialize(msg_data[1:]) + else: + msg = self.cross_lang_serializer.deserialize(msg_data[1:]) self.processor.process(msg) self.stopped = True diff --git a/streaming/python/runtime/transfer.py b/streaming/python/runtime/transfer.py index f40ea087a..a6beb03de 100644 --- a/streaming/python/runtime/transfer.py +++ b/streaming/python/runtime/transfer.py @@ -147,13 +147,17 @@ class ChannelCreationParametersBuilder: wrap initial parameters needed by a streaming queue """ _java_reader_async_function_descriptor = JavaFunctionDescriptor( - "io.ray.streaming.runtime.worker", "onReaderMessage", "([B)V") + "io.ray.streaming.runtime.worker.JobWorker", "onReaderMessage", + "([B)V") _java_reader_sync_function_descriptor = JavaFunctionDescriptor( - "io.ray.streaming.runtime.worker", "onReaderMessageSync", "([B)[B") + "io.ray.streaming.runtime.worker.JobWorker", "onReaderMessageSync", + "([B)[B") _java_writer_async_function_descriptor = JavaFunctionDescriptor( - "io.ray.streaming.runtime.worker", "onWriterMessage", "([B)V") + "io.ray.streaming.runtime.worker.JobWorker", "onWriterMessage", + "([B)V") _java_writer_sync_function_descriptor = JavaFunctionDescriptor( - "io.ray.streaming.runtime.worker", "onWriterMessageSync", "([B)[B") + "io.ray.streaming.runtime.worker.JobWorker", "onWriterMessageSync", + "([B)[B") _python_reader_async_function_descriptor = PythonFunctionDescriptor( "ray.streaming.runtime.worker", "on_reader_message", "JobWorker") _python_reader_sync_function_descriptor = PythonFunctionDescriptor( diff --git a/streaming/python/runtime/worker.py b/streaming/python/runtime/worker.py index 9743205ef..86e88a0f7 100644 --- a/streaming/python/runtime/worker.py +++ b/streaming/python/runtime/worker.py @@ -10,6 +10,9 @@ from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask logger = logging.getLogger(__name__) +# special flag to indicate this actor not ready +_NOT_READY_FLAG_ = b" " * 4 + @ray.remote class JobWorker(object): @@ -66,23 +69,31 @@ class JobWorker(object): type(self.stream_processor)) def on_reader_message(self, buffer: bytes): - """used in direct call mode""" + """Called by upstream queue writer to send data message to downstream + queue reader. + """ self.reader_client.on_reader_message(buffer) def on_reader_message_sync(self, buffer: bytes): - """used in direct call mode""" + """Called by upstream queue writer to send control message to downstream + downstream queue reader. + """ if self.reader_client is None: - return b" " * 4 # special flag to indicate this actor not ready + return _NOT_READY_FLAG_ result = self.reader_client.on_reader_message_sync(buffer) return result.to_pybytes() def on_writer_message(self, buffer: bytes): - """used in direct call mode""" + """Called by downstream queue reader to send notify message to + upstream queue writer. + """ self.writer_client.on_writer_message(buffer) def on_writer_message_sync(self, buffer: bytes): - """used in direct call mode""" + """Called by downstream queue reader to send control message to + upstream queue writer. + """ if self.writer_client is None: - return b" " * 4 # special flag to indicate this actor not ready + return _NOT_READY_FLAG_ result = self.writer_client.on_writer_message_sync(buffer) return result.to_pybytes() diff --git a/streaming/python/tests/test_function.py b/streaming/python/tests/test_function.py index 3564a1698..c9ce33067 100644 --- a/streaming/python/tests/test_function.py +++ b/streaming/python/tests/test_function.py @@ -14,9 +14,9 @@ class MapFunc(function.MapFunction): def test_load_function(): - # function_bytes, module_name, class_name, function_name, + # function_bytes, module_name, function_name/class_name, # function_interface descriptor_func_bytes = gateway_client.serialize( - [None, __name__, MapFunc.__name__, None, "MapFunction"]) + [None, __name__, MapFunc.__name__, "MapFunction"]) func = function.load_function(descriptor_func_bytes) assert type(func) is MapFunc diff --git a/streaming/python/tests/test_hybrid_stream.py b/streaming/python/tests/test_hybrid_stream.py new file mode 100644 index 000000000..a103be435 --- /dev/null +++ b/streaming/python/tests/test_hybrid_stream.py @@ -0,0 +1,70 @@ +import json +import ray +from ray.streaming import StreamingContext +import subprocess +import os + + +def map_func1(x): + print("HybridStreamTest map_func1", x) + return str(x) + + +def filter_func1(x): + print("HybridStreamTest filter_func1", x) + return "b" not in x + + +def sink_func1(x): + print("HybridStreamTest sink_func1 value:", x) + + +def test_hybrid_stream(): + subprocess.check_call( + ["bazel", "build", "//streaming/java:all_streaming_tests_deploy.jar"]) + current_dir = os.path.abspath(os.path.dirname(__file__)) + jar_path = os.path.join( + current_dir, + "../../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar") + jar_path = os.path.abspath(jar_path) + print("jar_path", jar_path) + java_worker_options = json.dumps(["-classpath", jar_path]) + print("java_worker_options", java_worker_options) + assert not ray.is_initialized() + ray.init( + load_code_from_local=True, + include_java=True, + java_worker_options=java_worker_options, + _internal_config=json.dumps({ + "num_workers_per_process_java": 1 + })) + + sink_file = "/tmp/ray_streaming_test_hybrid_stream.txt" + if os.path.exists(sink_file): + os.remove(sink_file) + + def sink_func(x): + print("HybridStreamTest", x) + with open(sink_file, "a") as f: + f.write(str(x)) + + ctx = StreamingContext.Builder().build() + ctx.from_values("a", "b", "c") \ + .as_java_stream() \ + .map("io.ray.streaming.runtime.demo.HybridStreamTest$Mapper1") \ + .filter("io.ray.streaming.runtime.demo.HybridStreamTest$Filter1") \ + .as_python_stream() \ + .sink(sink_func) + ctx.submit("HybridStreamTest") + import time + time.sleep(3) + ray.shutdown() + with open(sink_file, "r") as f: + result = f.read() + assert "a" in result + assert "b" not in result + assert "c" in result + + +if __name__ == "__main__": + test_hybrid_stream() diff --git a/streaming/python/tests/test_serialization.py b/streaming/python/tests/test_serialization.py new file mode 100644 index 000000000..67865f802 --- /dev/null +++ b/streaming/python/tests/test_serialization.py @@ -0,0 +1,13 @@ +from ray.streaming.runtime.serialization import CrossLangSerializer +from ray.streaming.message import Record, KeyRecord + + +def test_serialize(): + serializer = CrossLangSerializer() + record = Record("value") + record.stream = "stream1" + key_record = KeyRecord("key", "value") + key_record.stream = "stream2" + assert record == serializer.deserialize(serializer.serialize(record)) + assert key_record == serializer.\ + deserialize(serializer.serialize(key_record)) diff --git a/streaming/python/tests/test_stream.py b/streaming/python/tests/test_stream.py new file mode 100644 index 000000000..8eb0fbe6a --- /dev/null +++ b/streaming/python/tests/test_stream.py @@ -0,0 +1,31 @@ +import ray +from ray.streaming import StreamingContext + + +def test_data_stream(): + ray.init(load_code_from_local=True, include_java=True) + ctx = StreamingContext.Builder().build() + stream = ctx.from_values(1, 2, 3) + java_stream = stream.as_java_stream() + python_stream = java_stream.as_python_stream() + assert stream.get_id() == java_stream.get_id() + assert stream.get_id() == python_stream.get_id() + python_stream.set_parallelism(10) + assert stream.get_parallelism() == java_stream.get_parallelism() + assert stream.get_parallelism() == python_stream.get_parallelism() + ray.shutdown() + + +def test_key_data_stream(): + ray.init(load_code_from_local=True, include_java=True) + ctx = StreamingContext.Builder().build() + key_stream = ctx.from_values( + "a", "b", "c").map(lambda x: (x, 1)).key_by(lambda x: x[0]) + java_stream = key_stream.as_java_stream() + python_stream = java_stream.as_python_stream() + assert key_stream.get_id() == java_stream.get_id() + assert key_stream.get_id() == python_stream.get_id() + python_stream.set_parallelism(10) + assert key_stream.get_parallelism() == java_stream.get_parallelism() + assert key_stream.get_parallelism() == python_stream.get_parallelism() + ray.shutdown() diff --git a/streaming/python/tests/test_word_count.py b/streaming/python/tests/test_word_count.py index d86595cf4..03d9d7652 100644 --- a/streaming/python/tests/test_word_count.py +++ b/streaming/python/tests/test_word_count.py @@ -32,7 +32,9 @@ def test_simple_word_count(): def sink_func(x): with open(sink_file, "a") as f: - f.write("{}:{},".format(x[0], x[1])) + line = "{}:{},".format(x[0], x[1]) + print("sink_func", line) + f.write(line) ctx.from_values("a", "b", "c") \ .set_parallelism(1) \ diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.cc b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.cc index e2cb2e861..98acc36a1 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.cc +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.cc @@ -26,6 +26,13 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative( return reinterpret_cast(reader_client); } +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative( + JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) { + auto *writer_client = reinterpret_cast(ptr); + writer_client->OnWriterMessage(JByteArrayToBuffer(env, bytes)); +} + JNIEXPORT jbyteArray JNICALL Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative( JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {