mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)

[Streaming] Streaming Cross-Lang API (#7464)

parent 101255f782
commit 91f630f709

72 changed files with 1612 additions and 408 deletions
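What the change enables, in brief: Python's `ray.init()` gains a `java_worker_options` pass-through, a new `ClusterStarter` boots a multi-process cluster that can host both Java and Python workers, and the Java streaming API is reworked (self-typed `Stream`, proxy streams, new `PythonFunction`/`PythonPartition` constructors) so a Java pipeline can hand individual stages to Python. A minimal sketch of the intended Java-side usage, under stated assumptions: `StreamingContext.buildContext()` is assumed to be the existing factory, `PythonDataStream.map(...)` is assumed from elsewhere in this changeset, and the module/function names are placeholders.

import com.google.common.collect.Lists;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.stream.DataStreamSource;
import io.ray.streaming.python.PythonFunction;

StreamingContext context = StreamingContext.buildContext();   // assumed factory
DataStreamSource.fromCollection(context, Lists.newArrayList("a", "b", "c"))
    .asPythonStream()                                   // same logical stream, Python view
    .map(new PythonFunction("my_module", "my_func"))    // assumed Python-side operator
    .setParallelism(2);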
@@ -542,6 +542,7 @@ def init(address=None,
          raylet_socket_name=None,
          temp_dir=None,
          load_code_from_local=False,
+         java_worker_options=None,
          use_pickle=True,
          _internal_config=None,
          lru_evict=False):

@@ -651,6 +652,7 @@ def init(address=None,
             conventional location, e.g., "/tmp/ray".
         load_code_from_local: Whether code should be loaded from a local
             module or from the GCS.
+        java_worker_options: Overwrite the options to start Java workers.
         use_pickle: Deprecated.
         _internal_config (str): JSON configuration for overriding
             RayConfig defaults. For testing purposes ONLY.

@@ -758,6 +760,7 @@ def init(address=None,
             raylet_socket_name=raylet_socket_name,
             temp_dir=temp_dir,
             load_code_from_local=load_code_from_local,
+            java_worker_options=java_worker_options,
             _internal_config=_internal_config,
         )
         # Start the Ray processes. We set shutdown_at_exit=False because we

@@ -808,6 +811,9 @@ def init(address=None,
         if raylet_socket_name is not None:
             raise ValueError("When connecting to an existing cluster, "
                              "raylet_socket_name must not be provided.")
+        if java_worker_options is not None:
+            raise ValueError("When connecting to an existing cluster, "
+                             "java_worker_options must not be provided.")
         if _internal_config is not None and len(_internal_config) != 0:
            raise ValueError("When connecting to an existing cluster, "
                             "_internal_config must not be provided.")
@@ -39,6 +39,7 @@ define_java_module(
        ":io_ray_ray_streaming-state",
        ":io_ray_ray_streaming-api",
        "@ray_streaming_maven//:com_google_guava_guava",
+       "@ray_streaming_maven//:org_apache_commons_commons_lang3",
        "@ray_streaming_maven//:org_slf4j_slf4j_api",
        "@ray_streaming_maven//:org_slf4j_slf4j_log4j12",
        "@ray_streaming_maven//:org_testng_testng",

@@ -46,7 +47,12 @@ define_java_module(
    visibility = ["//visibility:public"],
    deps = [
        ":io_ray_ray_streaming-state",
+       "//java:io_ray_ray_api",
+       "//java:io_ray_ray_runtime",
+       "@ray_streaming_maven//:com_google_code_findbugs_jsr305",
+       "@ray_streaming_maven//:com_google_code_gson_gson",
        "@ray_streaming_maven//:com_google_guava_guava",
+       "@ray_streaming_maven//:org_apache_commons_commons_lang3",
        "@ray_streaming_maven//:org_slf4j_slf4j_api",
        "@ray_streaming_maven//:org_slf4j_slf4j_log4j12",
    ],

@@ -129,8 +135,9 @@ define_java_module(
        ":io_ray_ray_streaming-api",
        ":io_ray_ray_streaming-runtime",
        "@ray_streaming_maven//:com_google_guava_guava",
+       "@ray_streaming_maven//:com_google_code_findbugs_jsr305",
+       "@ray_streaming_maven//:org_apache_commons_commons_lang3",
        "@ray_streaming_maven//:de_ruedigermoeller_fst",
-       "@ray_streaming_maven//:org_msgpack_msgpack_core",
        "@ray_streaming_maven//:org_aeonbits_owner_owner",
        "@ray_streaming_maven//:org_slf4j_slf4j_api",
        "@ray_streaming_maven//:org_slf4j_slf4j_log4j12",

@@ -146,10 +153,12 @@ define_java_module(
        "//java:io_ray_ray_api",
        "//java:io_ray_ray_runtime",
        "@ray_streaming_maven//:com_github_davidmoten_flatbuffers_java",
+       "@ray_streaming_maven//:com_google_code_findbugs_jsr305",
        "@ray_streaming_maven//:com_google_guava_guava",
        "@ray_streaming_maven//:com_google_protobuf_protobuf_java",
        "@ray_streaming_maven//:de_ruedigermoeller_fst",
        "@ray_streaming_maven//:org_aeonbits_owner_owner",
+       "@ray_streaming_maven//:org_apache_commons_commons_lang3",
        "@ray_streaming_maven//:org_msgpack_msgpack_core",
        "@ray_streaming_maven//:org_slf4j_slf4j_api",
        "@ray_streaming_maven//:org_slf4j_slf4j_log4j12",
@@ -6,8 +6,11 @@ def gen_streaming_java_deps():
    artifacts = [
        "com.beust:jcommander:1.72",
        "com.google.guava:guava:27.0.1-jre",
+       "com.google.code.findbugs:jsr305:3.0.2",
+       "com.google.code.gson:gson:2.8.5",
        "com.github.davidmoten:flatbuffers-java:1.9.0.1",
        "com.google.protobuf:protobuf-java:3.8.0",
+       "org.apache.commons:commons-lang3:3.4",
        "de.ruedigermoeller:fst:2.57",
        "org.aeonbits.owner:owner:1.0.10",
        "org.slf4j:slf4j-api:1.7.12",

@@ -22,7 +25,6 @@ def gen_streaming_java_deps():
        "org.mockito:mockito-all:1.10.19",
        "org.powermock:powermock-module-testng:1.6.6",
        "org.powermock:powermock-api-mockito:1.6.6",
-       "org.projectlombok:lombok:1.16.20",
    ],
    repositories = [
        "https://repo1.maven.org/maven2/",
@@ -22,16 +22,36 @@
      <artifactId>ray-api</artifactId>
      <version>${project.version}</version>
    </dependency>
+    <dependency>
+      <groupId>io.ray</groupId>
+      <artifactId>ray-runtime</artifactId>
+      <version>${project.version}</version>
+    </dependency>
    <dependency>
      <groupId>org.ray</groupId>
      <artifactId>streaming-state</artifactId>
      <version>${project.version}</version>
    </dependency>
    <dependency>
+      <groupId>com.google.code.findbugs</groupId>
+      <artifactId>jsr305</artifactId>
+      <version>3.0.2</version>
+    </dependency>
+    <dependency>
+      <groupId>com.google.code.gson</groupId>
+      <artifactId>gson</artifactId>
+      <version>2.8.5</version>
+    </dependency>
+    <dependency>
      <groupId>com.google.guava</groupId>
      <artifactId>guava</artifactId>
      <version>27.0.1-jre</version>
    </dependency>
+    <dependency>
+      <groupId>org.apache.commons</groupId>
+      <artifactId>commons-lang3</artifactId>
+      <version>3.4</version>
+    </dependency>
    <dependency>
      <groupId>org.slf4j</groupId>
      <artifactId>slf4j-api</artifactId>

@@ -22,6 +22,11 @@
      <artifactId>ray-api</artifactId>
      <version>${project.version}</version>
    </dependency>
+    <dependency>
+      <groupId>io.ray</groupId>
+      <artifactId>ray-runtime</artifactId>
+      <version>${project.version}</version>
+    </dependency>
    <dependency>
      <groupId>org.ray</groupId>
      <artifactId>streaming-state</artifactId>
@@ -0,0 +1,129 @@
+package io.ray.streaming.api.context;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.gson.Gson;
+import io.ray.api.Ray;
+import io.ray.runtime.config.RayConfig;
+import io.ray.runtime.util.NetworkUtil;
+import java.io.File;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+class ClusterStarter {
+  private static final Logger LOG = LoggerFactory.getLogger(ClusterStarter.class);
+  private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/plasma_store_socket";
+  private static final String RAYLET_SOCKET_NAME = "/tmp/ray/raylet_socket";
+
+  static synchronized void startCluster(boolean isCrossLanguage, boolean isLocal) {
+    Preconditions.checkArgument(Ray.internal() == null);
+    RayConfig.reset();
+    if (!isLocal) {
+      System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
+      System.setProperty("ray.run-mode", "CLUSTER");
+    } else {
+      System.clearProperty("ray.raylet.config.num_workers_per_process_java");
+      System.setProperty("ray.run-mode", "SINGLE_PROCESS");
+    }
+
+    if (!isCrossLanguage) {
+      Ray.init();
+      return;
+    }
+
+    // Delete existing socket files.
+    for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
+      File file = new File(socket);
+      if (file.exists()) {
+        LOG.info("Delete existing socket file {}", file);
+        file.delete();
+      }
+    }
+
+    String nodeManagerPort = String.valueOf(NetworkUtil.getUnusedPort());
+
+    // Jars in the `ray` wheel don't contain test classes, so we add test classes explicitly.
+    // Since mvn test classes contain `test` in their path and bazel test classes are located
+    // in a jar with `test` in its name, we can check the classpath for `test` to filter out
+    // test classes.
+    String classpath = Stream.of(System.getProperty("java.class.path").split(":"))
+        .filter(s -> !s.contains(" ") && s.contains("test"))
+        .collect(Collectors.joining(":"));
+    String workerOptions = new Gson().toJson(ImmutableList.of("-classpath", classpath));
+    Map<String, String> config = new HashMap<>(RayConfig.create().rayletConfigParameters);
+    config.put("num_workers_per_process_java", "1");
+    // Start ray cluster.
+    List<String> startCommand = ImmutableList.of(
+        "ray",
+        "start",
+        "--head",
+        "--redis-port=6379",
+        String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
+        String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
+        String.format("--node-manager-port=%s", nodeManagerPort),
+        "--load-code-from-local",
+        "--include-java",
+        "--java-worker-options=" + workerOptions,
+        "--internal-config=" + new Gson().toJson(config)
+    );
+    if (!executeCommand(startCommand, 10)) {
+      throw new RuntimeException("Couldn't start ray cluster.");
+    }
+
+    // Connect to the cluster.
+    System.setProperty("ray.redis.address", "127.0.0.1:6379");
+    System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
+    System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
+    System.setProperty("ray.raylet.node-manager-port", nodeManagerPort);
+    Ray.init();
+  }
+
+  public static synchronized void stopCluster(boolean isCrossLanguage) {
+    // Disconnect from the cluster.
+    Ray.shutdown();
+    System.clearProperty("ray.redis.address");
+    System.clearProperty("ray.object-store.socket-name");
+    System.clearProperty("ray.raylet.socket-name");
+    System.clearProperty("ray.raylet.node-manager-port");
+    System.clearProperty("ray.raylet.config.num_workers_per_process_java");
+    System.clearProperty("ray.run-mode");
+
+    if (isCrossLanguage) {
+      // Stop ray cluster.
+      final List<String> stopCommand = ImmutableList.of(
+          "ray",
+          "stop"
+      );
+      if (!executeCommand(stopCommand, 10)) {
+        throw new RuntimeException("Couldn't stop ray cluster");
+      }
+    }
+  }
+
+  /**
+   * Execute an external command.
+   *
+   * @return Whether the command succeeded.
+   */
+  private static boolean executeCommand(List<String> command, int waitTimeoutSeconds) {
+    LOG.info("Executing command: {}", String.join(" ", command));
+    try {
+      ProcessBuilder processBuilder = new ProcessBuilder(command)
+          .redirectOutput(ProcessBuilder.Redirect.INHERIT)
+          .redirectError(ProcessBuilder.Redirect.INHERIT);
+      Process process = processBuilder.start();
+      boolean exit = process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS);
+      if (!exit) {
+        process.destroyForcibly();
+      }
+      return process.exitValue() == 0;
+    } catch (Exception e) {
+      throw new RuntimeException("Error executing command " + String.join(" ", command), e);
+    }
+  }
+}
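Worth noting: for cross-language jobs `ClusterStarter` shells out to the `ray` CLI (`ray start --head --include-java --java-worker-options=...`) instead of the single-process runtime, presumably because a cross-language graph cannot run in single-process mode (the `StreamingContext` change below checks exactly that). A hedged sketch of the calling pattern, mirroring what `StreamingContext` does below:

// Boot a cluster able to host Java and Python workers, run the job, tear down.
ClusterStarter.startCluster(/* isCrossLanguage */ true, /* isLocal */ false);
try {
  // ... build the job graph and submit it ...
} finally {
  ClusterStarter.stopCluster(/* isCrossLanguage */ true);
}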
@@ -1,10 +1,12 @@
 package io.ray.streaming.api.context;
 
 import com.google.common.base.Preconditions;
+import io.ray.api.Ray;
 import io.ray.streaming.api.stream.StreamSink;
 import io.ray.streaming.jobgraph.JobGraph;
 import io.ray.streaming.jobgraph.JobGraphBuilder;
 import io.ray.streaming.schedule.JobScheduler;
+import io.ray.streaming.util.Config;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.HashMap;

@@ -13,11 +15,14 @@ import java.util.List;
 import java.util.Map;
 import java.util.ServiceLoader;
 import java.util.concurrent.atomic.AtomicInteger;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * Encapsulate the context information of a streaming Job.
  */
 public class StreamingContext implements Serializable {
+  private static final Logger LOG = LoggerFactory.getLogger(StreamingContext.class);
 
   private transient AtomicInteger idGenerator;

@@ -54,6 +59,20 @@ public class StreamingContext implements Serializable {
     this.jobGraph = jobGraphBuilder.build();
     jobGraph.printJobGraph();
+
+    if (Ray.internal() == null) {
+      if (Config.MEMORY_CHANNEL.equalsIgnoreCase(jobConfig.get(Config.CHANNEL_TYPE))) {
+        Preconditions.checkArgument(!jobGraph.isCrossLanguageGraph());
+        ClusterStarter.startCluster(false, true);
+        LOG.info("Created local cluster for job {}.", jobName);
+      } else {
+        ClusterStarter.startCluster(jobGraph.isCrossLanguageGraph(), false);
+        LOG.info("Created multi process cluster for job {}.", jobName);
+      }
+      Runtime.getRuntime().addShutdownHook(new Thread(StreamingContext.this::stop));
+    } else {
+      LOG.info("Reuse existing cluster.");
+    }
 
     ServiceLoader<JobScheduler> serviceLoader = ServiceLoader.load(JobScheduler.class);
     Iterator<JobScheduler> iterator = serviceLoader.iterator();
     Preconditions.checkArgument(iterator.hasNext(),

@@ -77,4 +96,10 @@ public class StreamingContext implements Serializable {
   public void withConfig(Map<String, String> jobConfig) {
     this.jobConfig = jobConfig;
   }
+
+  public void stop() {
+    if (Ray.internal() != null) {
+      ClusterStarter.stopCluster(jobGraph.isCrossLanguageGraph());
+    }
+  }
 }
@@ -1,6 +1,7 @@
 package io.ray.streaming.api.stream;
 
+import io.ray.streaming.api.Language;
 import io.ray.streaming.api.context.StreamingContext;
 import io.ray.streaming.api.function.impl.FilterFunction;
 import io.ray.streaming.api.function.impl.FlatMapFunction;

@@ -15,24 +16,44 @@ import io.ray.streaming.operator.impl.FlatMapOperator;
 import io.ray.streaming.operator.impl.KeyByOperator;
 import io.ray.streaming.operator.impl.MapOperator;
 import io.ray.streaming.operator.impl.SinkOperator;
+import io.ray.streaming.python.stream.PythonDataStream;
 
 /**
  * Represents a stream of data.
- *
- * This class defines all the streaming operations.
+ * <p>This class defines all the streaming operations.
  *
  * @param <T> Type of data in the stream.
  */
-public class DataStream<T> extends Stream<T> {
+public class DataStream<T> extends Stream<DataStream<T>, T> {
 
   public DataStream(StreamingContext streamingContext, StreamOperator streamOperator) {
     super(streamingContext, streamOperator);
   }
 
-  public DataStream(DataStream input, StreamOperator streamOperator) {
+  public DataStream(StreamingContext streamingContext,
+                    StreamOperator streamOperator,
+                    Partition<T> partition) {
+    super(streamingContext, streamOperator, partition);
+  }
+
+  public <R> DataStream(DataStream<R> input, StreamOperator streamOperator) {
     super(input, streamOperator);
   }
+
+  public <R> DataStream(DataStream<R> input,
+                        StreamOperator streamOperator,
+                        Partition<T> partition) {
+    super(input, streamOperator, partition);
+  }
+
+  /**
+   * Create a java stream that references the passed python stream.
+   * Changes in the new stream will be reflected in the referenced stream and vice versa.
+   */
+  public DataStream(PythonDataStream referencedStream) {
+    super(referencedStream);
+  }
 
   /**
    * Apply a map function to this stream.
   *

@@ -41,7 +62,7 @@ public class DataStream<T> extends Stream<T> {
    * @return A new DataStream.
    */
   public <R> DataStream<R> map(MapFunction<T, R> mapFunction) {
-    return new DataStream<>(this, new MapOperator(mapFunction));
+    return new DataStream<>(this, new MapOperator<>(mapFunction));
   }
 
   /**

@@ -52,11 +73,11 @@ public class DataStream<T> extends Stream<T> {
    * @return A new DataStream
    */
   public <R> DataStream<R> flatMap(FlatMapFunction<T, R> flatMapFunction) {
-    return new DataStream(this, new FlatMapOperator(flatMapFunction));
+    return new DataStream<>(this, new FlatMapOperator<>(flatMapFunction));
   }
 
   public DataStream<T> filter(FilterFunction<T> filterFunction) {
-    return new DataStream<T>(this, new FilterOperator(filterFunction));
+    return new DataStream<>(this, new FilterOperator<>(filterFunction));
   }
 
   /**

@@ -66,7 +87,7 @@ public class DataStream<T> extends Stream<T> {
    * @return A new UnionStream.
    */
   public UnionStream<T> union(DataStream<T> other) {
-    return new UnionStream(this, null, other);
+    return new UnionStream<>(this, null, other);
   }
 
   /**

@@ -93,7 +114,7 @@ public class DataStream<T> extends Stream<T> {
    * @return A new StreamSink.
    */
   public DataStreamSink<T> sink(SinkFunction<T> sinkFunction) {
-    return new DataStreamSink<>(this, new SinkOperator(sinkFunction));
+    return new DataStreamSink<>(this, new SinkOperator<>(sinkFunction));
   }
 
   /**

@@ -104,7 +125,8 @@ public class DataStream<T> extends Stream<T> {
    * @return A new KeyDataStream.
    */
   public <K> KeyDataStream<K, T> keyBy(KeyFunction<T, K> keyFunction) {
-    return new KeyDataStream<>(this, new KeyByOperator(keyFunction));
+    checkPartitionCall();
+    return new KeyDataStream<>(this, new KeyByOperator<>(keyFunction));
   }
 
   /**

@@ -113,8 +135,8 @@ public class DataStream<T> extends Stream<T> {
    * @return This stream.
    */
   public DataStream<T> broadcast() {
-    this.partition = new BroadcastPartition<>();
-    return this;
+    checkPartitionCall();
+    return setPartition(new BroadcastPartition<>());
   }
 
   /**

@@ -124,19 +146,32 @@ public class DataStream<T> extends Stream<T> {
    * @return This stream.
    */
   public DataStream<T> partitionBy(Partition<T> partition) {
-    this.partition = partition;
-    return this;
+    checkPartitionCall();
+    return setPartition(partition);
   }
 
   /**
-   * Set parallelism to current transformation.
-   *
-   * @param parallelism The parallelism to set.
-   * @return This stream.
+   * If the parent stream is a python stream, partition-related methods
+   * can't be called on the java stream.
    */
-  public DataStream<T> setParallelism(int parallelism) {
-    this.parallelism = parallelism;
-    return this;
+  private void checkPartitionCall() {
+    if (getInputStream() != null && getInputStream().getLanguage() == Language.PYTHON) {
+      throw new RuntimeException("Partition related methods can't be called on a " +
+          "java stream if parent stream is a python stream.");
+    }
   }
+
+  /**
+   * Convert this stream into a python stream.
+   * The converted stream and this stream are the same logical stream and share the same
+   * stream id. Changes in the converted stream will be reflected in this stream and vice versa.
+   */
+  public PythonDataStream asPythonStream() {
+    return new PythonDataStream(this);
+  }
+
+  @Override
+  public Language getLanguage() {
+    return Language.JAVA;
+  }
 }
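The `asPythonStream()` conversion and the referencing constructor make a Java stream and its Python view two proxies of one logical stream (same id, shared partition and parallelism; see `Stream` below). A hedged sketch of the intended round trip; `PythonDataStream.map` is assumed from the rest of the changeset and the names are placeholders:

void example(DataStream<String> javaStream) {
  PythonDataStream pyView = javaStream.asPythonStream();   // same logical stream, same id
  // pyView.map(new PythonFunction("my_module", "my_func"));  // assumed Python-side operator
  // If the parent stream were a Python stream, checkPartitionCall() above would
  // reject broadcast()/keyBy()/partitionBy() on the Java view.
  javaStream.broadcast();
}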
@@ -1,5 +1,6 @@
 package io.ray.streaming.api.stream;
 
+import io.ray.streaming.api.Language;
 import io.ray.streaming.operator.impl.SinkOperator;
 
 /**

@@ -9,13 +10,13 @@ import io.ray.streaming.operator.impl.SinkOperator;
  */
 public class DataStreamSink<T> extends StreamSink<T> {
 
-  public DataStreamSink(DataStream<T> input, SinkOperator sinkOperator) {
+  public DataStreamSink(DataStream input, SinkOperator sinkOperator) {
     super(input, sinkOperator);
-    this.streamingContext.addSink(this);
+    getStreamingContext().addSink(this);
   }
 
-  public DataStreamSink<T> setParallelism(int parallelism) {
-    this.parallelism = parallelism;
-    return this;
+  @Override
+  public Language getLanguage() {
+    return Language.JAVA;
   }
 }
@@ -14,9 +14,13 @@ import java.util.Collection;
  */
 public class DataStreamSource<T> extends DataStream<T> implements StreamSource<T> {
 
-  public DataStreamSource(StreamingContext streamingContext, SourceFunction<T> sourceFunction) {
-    super(streamingContext, new SourceOperator<>(sourceFunction));
-    super.partition = new RoundRobinPartition<>();
+  private DataStreamSource(StreamingContext streamingContext, SourceFunction<T> sourceFunction) {
+    super(streamingContext, new SourceOperator<>(sourceFunction), new RoundRobinPartition<>());
+  }
+
+  public static <T> DataStreamSource<T> fromSource(
+      StreamingContext context, SourceFunction<T> sourceFunction) {
+    return new DataStreamSource<>(context, sourceFunction);
   }
 
   /**

@@ -27,14 +31,9 @@ public class DataStreamSource<T> extends DataStream<T> implements StreamSource<T>
    * @param <T> The type of source data.
    * @return A DataStreamSource.
    */
-  public static <T> DataStreamSource<T> buildSource(
+  public static <T> DataStreamSource<T> fromCollection(
       StreamingContext context, Collection<T> values) {
-    return new DataStreamSource(context, new CollectionSourceFunction(values));
+    return new DataStreamSource<>(context, new CollectionSourceFunction<>(values));
   }
-
-  @Override
-  public DataStreamSource<T> setParallelism(int parallelism) {
-    this.parallelism = parallelism;
-    return this;
-  }
 }
@@ -2,9 +2,12 @@ package io.ray.streaming.api.stream;
 
 import io.ray.streaming.api.function.impl.AggregateFunction;
 import io.ray.streaming.api.function.impl.ReduceFunction;
+import io.ray.streaming.api.partition.Partition;
 import io.ray.streaming.api.partition.impl.KeyPartition;
 import io.ray.streaming.operator.StreamOperator;
 import io.ray.streaming.operator.impl.ReduceOperator;
+import io.ray.streaming.python.stream.PythonDataStream;
+import io.ray.streaming.python.stream.PythonKeyDataStream;
 
 /**
  * Represents a DataStream returned by a key-by operation.

@@ -12,11 +15,19 @@ import io.ray.streaming.operator.impl.ReduceOperator;
  * @param <K> Type of the key.
  * @param <T> Type of the data.
  */
+@SuppressWarnings("unchecked")
 public class KeyDataStream<K, T> extends DataStream<T> {
 
   public KeyDataStream(DataStream<T> input, StreamOperator streamOperator) {
-    super(input, streamOperator);
-    this.partition = new KeyPartition();
+    super(input, streamOperator, (Partition<T>) new KeyPartition<K, T>());
+  }
+
+  /**
+   * Create a java stream that references the passed python stream.
+   * Changes in the new stream will be reflected in the referenced stream and vice versa.
+   */
+  public KeyDataStream(PythonDataStream referencedStream) {
+    super(referencedStream);
   }
 
   /**

@@ -41,8 +52,13 @@ public class KeyDataStream<K, T> extends DataStream<T> {
     return new DataStream<>(this, null);
   }
 
-  public KeyDataStream<K, T> setParallelism(int parallelism) {
-    this.parallelism = parallelism;
-    return this;
+  /**
+   * Convert this stream into a python stream.
+   * The converted stream and this stream are the same logical stream and share the same
+   * stream id. Changes in the converted stream will be reflected in this stream and vice versa.
+   */
+  public PythonKeyDataStream asPythonStream() {
+    return new PythonKeyDataStream(this);
   }
 
 }
@@ -1,58 +1,99 @@
 package io.ray.streaming.api.stream;
 
+import com.google.common.base.Preconditions;
+import io.ray.streaming.api.Language;
 import io.ray.streaming.api.context.StreamingContext;
 import io.ray.streaming.api.partition.Partition;
 import io.ray.streaming.api.partition.impl.RoundRobinPartition;
+import io.ray.streaming.operator.Operator;
 import io.ray.streaming.operator.StreamOperator;
-import io.ray.streaming.python.PythonOperator;
 import io.ray.streaming.python.PythonPartition;
-import io.ray.streaming.python.stream.PythonStream;
 import java.io.Serializable;
 
 /**
  * Abstract base class of all stream types.
  *
+ * @param <S> Type of stream class
  * @param <T> Type of the data in the stream.
  */
-public abstract class Stream<T> implements Serializable {
-  protected int id;
-  protected int parallelism = 1;
-  protected StreamOperator operator;
-  protected Stream<T> inputStream;
-  protected StreamingContext streamingContext;
-  protected Partition<T> partition;
+public abstract class Stream<S extends Stream<S, T>, T>
+    implements Serializable {
+  private final int id;
+  private final StreamingContext streamingContext;
+  private final Stream inputStream;
+  private final StreamOperator operator;
+  private int parallelism = 1;
+  private Partition<T> partition;
+  private Stream originalStream;
 
-  @SuppressWarnings("unchecked")
   public Stream(StreamingContext streamingContext, StreamOperator streamOperator) {
+    this(streamingContext, null, streamOperator,
+        selectPartition(streamOperator));
+  }
+
+  public Stream(StreamingContext streamingContext,
+                StreamOperator streamOperator,
+                Partition<T> partition) {
+    this(streamingContext, null, streamOperator, partition);
+  }
+
+  public Stream(Stream inputStream, StreamOperator streamOperator) {
+    this(inputStream.getStreamingContext(), inputStream, streamOperator,
+        selectPartition(streamOperator));
+  }
+
+  public Stream(Stream inputStream, StreamOperator streamOperator, Partition<T> partition) {
+    this(inputStream.getStreamingContext(), inputStream, streamOperator, partition);
+  }
+
+  protected Stream(StreamingContext streamingContext,
+                   Stream inputStream,
+                   StreamOperator streamOperator,
+                   Partition<T> partition) {
     this.streamingContext = streamingContext;
+    this.inputStream = inputStream;
     this.operator = streamOperator;
+    this.partition = partition;
     this.id = streamingContext.generateId();
-    if (streamOperator instanceof PythonOperator) {
-      this.partition = PythonPartition.RoundRobinPartition;
-    } else {
-      this.partition = new RoundRobinPartition<>();
+    if (inputStream != null) {
+      this.parallelism = inputStream.getParallelism();
     }
   }
 
-  public Stream(Stream<T> inputStream, StreamOperator streamOperator) {
-    this.inputStream = inputStream;
-    this.parallelism = inputStream.getParallelism();
-    this.streamingContext = this.inputStream.getStreamingContext();
-    this.operator = streamOperator;
-    this.id = streamingContext.generateId();
-    this.partition = selectPartition();
+  /**
+   * Create a proxy stream of an original stream.
+   * Changes in the new stream will be reflected in the original stream and vice versa.
+   */
+  protected Stream(Stream originalStream) {
+    this.originalStream = originalStream;
+    this.id = originalStream.getId();
+    this.streamingContext = originalStream.getStreamingContext();
+    this.inputStream = originalStream.getInputStream();
+    this.operator = originalStream.getOperator();
   }
 
   @SuppressWarnings("unchecked")
-  private Partition<T> selectPartition() {
-    if (inputStream instanceof PythonStream) {
-      return PythonPartition.RoundRobinPartition;
-    } else {
-      return new RoundRobinPartition<>();
+  private static <T> Partition<T> selectPartition(Operator operator) {
+    switch (operator.getLanguage()) {
+      case PYTHON:
+        return (Partition<T>) PythonPartition.RoundRobinPartition;
+      case JAVA:
+        return new RoundRobinPartition<>();
+      default:
+        throw new UnsupportedOperationException(
+            "Unsupported language " + operator.getLanguage());
     }
   }
 
-  public Stream<T> getInputStream() {
+  public int getId() {
+    return id;
+  }
+
+  public StreamingContext getStreamingContext() {
+    return streamingContext;
+  }
+
+  public Stream getInputStream() {
     return inputStream;
   }
 

@@ -60,32 +101,47 @@ public abstract class Stream<T> implements Serializable {
     return operator;
   }
 
-  public void setOperator(StreamOperator operator) {
-    this.operator = operator;
-  }
-
-  public StreamingContext getStreamingContext() {
-    return streamingContext;
+  @SuppressWarnings("unchecked")
+  private S self() {
+    return (S) this;
   }
 
   public int getParallelism() {
-    return parallelism;
+    return originalStream != null ? originalStream.getParallelism() : parallelism;
   }
 
-  public Stream<T> setParallelism(int parallelism) {
-    this.parallelism = parallelism;
-    return this;
-  }
-
-  public int getId() {
-    return id;
+  public S setParallelism(int parallelism) {
+    if (originalStream != null) {
+      originalStream.setParallelism(parallelism);
+    } else {
+      this.parallelism = parallelism;
+    }
+    return self();
   }
 
+  @SuppressWarnings("unchecked")
   public Partition<T> getPartition() {
-    return partition;
+    return originalStream != null ? originalStream.getPartition() : partition;
   }
 
-  public void setPartition(Partition<T> partition) {
+  @SuppressWarnings("unchecked")
+  protected S setPartition(Partition<T> partition) {
+    if (originalStream != null) {
+      originalStream.setPartition(partition);
+    } else {
       this.partition = partition;
+    }
+    return self();
   }
+
+  public boolean isProxyStream() {
+    return originalStream != null;
+  }
+
+  public Stream getOriginalStream() {
+    Preconditions.checkArgument(isProxyStream());
+    return originalStream;
+  }
+
+  public abstract Language getLanguage();
 }
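`Stream` now carries a self type `S` (the curiously recurring template pattern), so builder-style setters like `setParallelism` return the concrete subclass instead of the base type and can be chained without casts. A minimal illustration of the idiom, separate from this codebase (all names here are illustrative):

// F-bounded self type: setName() returns the concrete subclass, so subclass
// methods can be chained after base-class methods without casts.
abstract class Node<S extends Node<S>> {
  String name;

  @SuppressWarnings("unchecked")
  S setName(String name) {   // returns S, not Node
    this.name = name;
    return (S) this;
  }
}

class LeafNode extends Node<LeafNode> {
  int value;

  LeafNode setValue(int v) {
    this.value = v;
    return this;
  }
}

// LeafNode chain = new LeafNode().setName("a").setValue(1);  // compiles without a cast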
@@ -7,8 +7,8 @@ import io.ray.streaming.operator.StreamOperator;
  *
  * @param <T> Type of the input data of this sink.
  */
-public class StreamSink<T> extends Stream<T> {
-  public StreamSink(Stream<T> inputStream, StreamOperator streamOperator) {
+public abstract class StreamSink<T> extends Stream<StreamSink<T>, T> {
+  public StreamSink(Stream inputStream, StreamOperator streamOperator) {
     super(inputStream, streamOperator);
   }
 }
@@ -11,15 +11,15 @@ import java.util.List;
  */
 public class UnionStream<T> extends DataStream<T> {
 
-  private List<DataStream> unionStreams;
+  private List<DataStream<T>> unionStreams;
 
-  public UnionStream(DataStream input, StreamOperator streamOperator, DataStream<T> other) {
+  public UnionStream(DataStream<T> input, StreamOperator streamOperator, DataStream<T> other) {
     super(input, streamOperator);
     this.unionStreams = new ArrayList<>();
     this.unionStreams.add(other);
   }
 
-  public List<DataStream> getUnionStreams() {
+  public List<DataStream<T>> getUnionStreams() {
     return unionStreams;
   }
 }
@@ -1,5 +1,6 @@
 package io.ray.streaming.jobgraph;
 
+import io.ray.streaming.api.Language;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.List;

@@ -97,4 +98,14 @@ public class JobGraph implements Serializable {
     }
   }
 
+  public boolean isCrossLanguageGraph() {
+    Language language = jobVertexList.get(0).getLanguage();
+    for (JobVertex jobVertex : jobVertexList) {
+      if (jobVertex.getLanguage() != language) {
+        return true;
+      }
+    }
+    return false;
+  }
+
 }
@@ -1,5 +1,6 @@
 package io.ray.streaming.jobgraph;
 
+import com.google.common.base.Preconditions;
 import io.ray.streaming.api.stream.DataStream;
 import io.ray.streaming.api.stream.Stream;
 import io.ray.streaming.api.stream.StreamSink;

@@ -10,8 +11,11 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicInteger;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 public class JobGraphBuilder {
+  private static final Logger LOG = LoggerFactory.getLogger(JobGraphBuilder.class);
 
   private JobGraph jobGraph;

@@ -41,12 +45,19 @@ public class JobGraphBuilder {
   }
 
   private void processStream(Stream stream) {
+    while (stream.isProxyStream()) {
+      // A proxy stream and its original stream are the same logical stream; both refer to
+      // the same data flow transformation. We skip the proxy stream to avoid applying the
+      // same transformation multiple times.
+      LOG.debug("Skip proxy stream {} of id {}", stream, stream.getId());
+      stream = stream.getOriginalStream();
+    }
+    StreamOperator streamOperator = stream.getOperator();
+    Preconditions.checkArgument(stream.getLanguage() == streamOperator.getLanguage(),
+        "Reference stream should be skipped.");
     int vertexId = stream.getId();
     int parallelism = stream.getParallelism();
-    StreamOperator streamOperator = stream.getOperator();
-    JobVertex jobVertex = null;
+    JobVertex jobVertex;
 
     if (stream instanceof StreamSink) {
       jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator);
       Stream parentStream = stream.getInputStream();
@@ -1,6 +1,8 @@
 package io.ray.streaming.message;
 
 
+import java.util.Objects;
+
 public class KeyRecord<K, T> extends Record<T> {
 
   private K key;

@@ -17,4 +19,24 @@ public class KeyRecord<K, T> extends Record<T> {
   public void setKey(K key) {
     this.key = key;
   }
 
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    if (!super.equals(o)) {
+      return false;
+    }
+    KeyRecord<?, ?> keyRecord = (KeyRecord<?, ?>) o;
+    return Objects.equals(key, keyRecord.key);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(super.hashCode(), key);
+  }
 }
@@ -1,64 +0,0 @@
-package io.ray.streaming.message;
-
-import com.google.common.collect.Lists;
-import java.io.Serializable;
-import java.util.List;
-
-public class Message implements Serializable {
-
-  private int taskId;
-  private long batchId;
-  private String stream;
-  private List<Record> recordList;
-
-  public Message(int taskId, long batchId, String stream, List<Record> recordList) {
-    this.taskId = taskId;
-    this.batchId = batchId;
-    this.stream = stream;
-    this.recordList = recordList;
-  }
-
-  public Message(int taskId, long batchId, String stream, Record record) {
-    this.taskId = taskId;
-    this.batchId = batchId;
-    this.stream = stream;
-    this.recordList = Lists.newArrayList(record);
-  }
-
-  public int getTaskId() {
-    return taskId;
-  }
-
-  public void setTaskId(int taskId) {
-    this.taskId = taskId;
-  }
-
-  public long getBatchId() {
-    return batchId;
-  }
-
-  public void setBatchId(long batchId) {
-    this.batchId = batchId;
-  }
-
-  public String getStream() {
-    return stream;
-  }
-
-  public void setStream(String stream) {
-    this.stream = stream;
-  }
-
-  public List<Record> getRecordList() {
-    return recordList;
-  }
-
-  public void setRecordList(List<Record> recordList) {
-    this.recordList = recordList;
-  }
-
-  public Record getRecord(int index) {
-    return recordList.get(0);
-  }
-
-}
@@ -1,6 +1,7 @@
 package io.ray.streaming.message;
 
 import java.io.Serializable;
+import java.util.Objects;
 
 
 public class Record<T> implements Serializable {

@@ -27,6 +28,24 @@ public class Record<T> implements Serializable {
     this.stream = stream;
   }
 
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    Record<?> record = (Record<?>) o;
+    return Objects.equals(stream, record.stream) &&
+        Objects.equals(value, record.value);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(stream, value);
+  }
+
   @Override
   public String toString() {
     return value.toString();
@@ -1,6 +1,8 @@
 package io.ray.streaming.python;
 
+import com.google.common.base.Preconditions;
 import io.ray.streaming.api.function.Function;
+import org.apache.commons.lang3.StringUtils;
 
 /**
  * Represents a user defined python function.

@@ -14,9 +16,8 @@ import io.ray.streaming.api.function.Function;
  *
  * <p>If the python data stream api is invoked from python, `function` will be not null.</p>
  * <p>If the python data stream api is invoked from java, `moduleName` and
- * `className`/`functionName` will be not null.</p>
+ * `functionName` will be not null.</p>
  * <p>
- * TODO serialize to bytes using protobuf
  */
 public class PythonFunction implements Function {
   public enum FunctionInterface {

@@ -38,23 +39,43 @@ public class PythonFunction implements Function {
     }
   }
 
-  private byte[] function;
-  private String moduleName;
-  private String className;
-  private String functionName;
+  // null if this function is constructed from moduleName/functionName.
+  private final byte[] function;
+  // null if this function is constructed from a serialized python function.
+  private final String moduleName;
+  // null if this function is constructed from a serialized python function.
+  private final String functionName;
   /**
    * FunctionInterface can be used to validate python function,
    * and look up operator class from FunctionInterface.
    */
   private String functionInterface;
 
-  private PythonFunction(byte[] function,
-                         String moduleName,
-                         String className,
-                         String functionName) {
+  /**
+   * Create a {@link PythonFunction} from a serialized streaming python function.
+   *
+   * @param function serialized streaming python function from the python driver.
+   */
+  public PythonFunction(byte[] function) {
+    Preconditions.checkNotNull(function);
     this.function = function;
+    this.moduleName = null;
+    this.functionName = null;
+  }
+
+  /**
+   * Create a {@link PythonFunction} from a moduleName and streaming function name.
+   *
+   * @param moduleName module name of the streaming function.
+   * @param functionName function name of the streaming function. {@code functionName} is the
+   *     name of a python function, or the class name of a subclass of `ray.streaming.function.`
+   */
+  public PythonFunction(String moduleName,
+                        String functionName) {
+    Preconditions.checkArgument(StringUtils.isNotBlank(moduleName));
+    Preconditions.checkArgument(StringUtils.isNotBlank(functionName));
+    this.function = null;
     this.moduleName = moduleName;
-    this.className = className;
     this.functionName = functionName;
   }
 

@@ -70,10 +91,6 @@ public class PythonFunction implements Function {
     return moduleName;
   }
 
-  public String getClassName() {
-    return className;
-  }
-
   public String getFunctionName() {
     return functionName;
   }

@@ -82,34 +99,4 @@ public class PythonFunction implements Function {
     return functionInterface;
   }
-
-  /**
-   * Create a {@link PythonFunction} using python serialized function
-   *
-   * @param function serialized python function sent from python driver
-   */
-  public static PythonFunction fromFunction(byte[] function) {
-    return new PythonFunction(function, null, null, null);
-  }
-
-  /**
-   * Create a {@link PythonFunction} using <code>moduleName</code> and
-   * <code>className</code>.
-   *
-   * @param moduleName python module name
-   * @param className python class name
-   */
-  public static PythonFunction fromClassName(String moduleName, String className) {
-    return new PythonFunction(null, moduleName, className, null);
-  }
-
-  /**
-   * Create a {@link PythonFunction} using <code>moduleName</code> and
-   * <code>functionName</code>.
-   *
-   * @param moduleName python module name
-   * @param functionName python function name
-   */
-  public static PythonFunction fromFunctionName(String moduleName, String functionName) {
-    return new PythonFunction(null, moduleName, null, functionName);
-  }
 }
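The two public constructors replace the old static factories, and which fields are null now encodes which side created the function. A hedged sketch of both paths; the module/function names and the `fetchSerializedPythonFunction` helper are placeholders, not part of this API:

// From Java: the python worker resolves the function via importlib.
PythonFunction byName = new PythonFunction("my_module", "my_map_func");

// From a python driver: the function arrives already serialized.
byte[] serialized = fetchSerializedPythonFunction();  // hypothetical transport helper
PythonFunction byBytes = new PythonFunction(serialized);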
@@ -1,6 +1,8 @@
 package io.ray.streaming.python;

+import com.google.common.base.Preconditions;
 import io.ray.streaming.api.partition.Partition;
+import org.apache.commons.lang3.StringUtils;

 /**
  * Represents a python partition function.
@@ -13,28 +15,33 @@ import io.ray.streaming.api.partition.Partition;
  * If this object is constructed from moduleName and className/functionName,
  * python worker will use `importlib` to load python partition function.
  * <p>
- * TODO serialize to bytes using protobuf
  */
-public class PythonPartition implements Partition {
+public class PythonPartition implements Partition<Object> {
   public static final PythonPartition BroadcastPartition = new PythonPartition(
-      "ray.streaming.partition", "BroadcastPartition", null);
+      "ray.streaming.partition", "BroadcastPartition");
   public static final PythonPartition KeyPartition = new PythonPartition(
-      "ray.streaming.partition", "KeyPartition", null);
+      "ray.streaming.partition", "KeyPartition");
   public static final PythonPartition RoundRobinPartition = new PythonPartition(
-      "ray.streaming.partition", "RoundRobinPartition", null);
+      "ray.streaming.partition", "RoundRobinPartition");

   private byte[] partition;
   private String moduleName;
-  private String className;
   private String functionName;

   public PythonPartition(byte[] partition) {
+    Preconditions.checkNotNull(partition);
     this.partition = partition;
   }

-  public PythonPartition(String moduleName, String className, String functionName) {
+  /**
+   * Create a python partition from a moduleName and partition function name.
+   *
+   * @param moduleName module name of python partition
+   * @param functionName function/class name of the partition function.
+   */
+  public PythonPartition(String moduleName, String functionName) {
+    Preconditions.checkArgument(StringUtils.isNotBlank(moduleName));
+    Preconditions.checkArgument(StringUtils.isNotBlank(functionName));
     this.moduleName = moduleName;
-    this.className = className;
     this.functionName = functionName;
   }

@@ -53,10 +60,6 @@ public class PythonPartition implements Partition {
     return moduleName;
   }

-  public String getClassName() {
-    return className;
-  }
-
   public String getFunctionName() {
     return functionName;
   }
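Note: custom partitions follow the same by-name convention as functions. Sketch with placeholder names; blank names now fail fast via the Preconditions checks above:

    PythonPartition custom = new PythonPartition("my_module", "my_partition_func");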
@@ -1,6 +1,9 @@
 package io.ray.streaming.python.stream;

+import io.ray.streaming.api.Language;
 import io.ray.streaming.api.context.StreamingContext;
+import io.ray.streaming.api.partition.Partition;
+import io.ray.streaming.api.stream.DataStream;
 import io.ray.streaming.api.stream.Stream;
 import io.ray.streaming.python.PythonFunction;
 import io.ray.streaming.python.PythonFunction.FunctionInterface;
@@ -10,19 +13,39 @@ import io.ray.streaming.python.PythonPartition;
 /**
  * Represents a stream of data whose transformations will be executed in python.
  */
-public class PythonDataStream extends Stream implements PythonStream {
+public class PythonDataStream extends Stream<PythonDataStream, Object> implements PythonStream {

   protected PythonDataStream(StreamingContext streamingContext,
                              PythonOperator pythonOperator) {
     super(streamingContext, pythonOperator);
   }

+  protected PythonDataStream(StreamingContext streamingContext,
+                             PythonOperator pythonOperator,
+                             Partition<Object> partition) {
+    super(streamingContext, pythonOperator, partition);
+  }
+
   public PythonDataStream(PythonDataStream input, PythonOperator pythonOperator) {
     super(input, pythonOperator);
   }

-  protected PythonDataStream(Stream inputStream, PythonOperator pythonOperator) {
-    super(inputStream, pythonOperator);
+  public PythonDataStream(PythonDataStream input,
+                          PythonOperator pythonOperator,
+                          Partition<Object> partition) {
+    super(input, pythonOperator, partition);
+  }
+
+  /**
+   * Create a python stream that references the passed java stream.
+   * Changes in the new stream will be reflected in the referenced stream and vice versa.
+   */
+  public PythonDataStream(DataStream referencedStream) {
+    super(referencedStream);
+  }
+
+  public PythonDataStream map(String moduleName, String funcName) {
+    return map(new PythonFunction(moduleName, funcName));
   }

   /**
@@ -36,6 +59,10 @@ public class PythonDataStream extends Stream implements PythonStream {
     return new PythonDataStream(this, new PythonOperator(func));
   }

+  public PythonDataStream flatMap(String moduleName, String funcName) {
+    return flatMap(new PythonFunction(moduleName, funcName));
+  }
+
   /**
    * Apply a flat-map function to this stream.
    *
@@ -47,6 +74,10 @@ public class PythonDataStream extends Stream implements PythonStream {
     return new PythonDataStream(this, new PythonOperator(func));
   }

+  public PythonDataStream filter(String moduleName, String funcName) {
+    return filter(new PythonFunction(moduleName, funcName));
+  }
+
   /**
    * Apply a filter function to this stream.
    *
@@ -59,6 +90,10 @@ public class PythonDataStream extends Stream implements PythonStream {
     return new PythonDataStream(this, new PythonOperator(func));
   }

+  public PythonStreamSink sink(String moduleName, String funcName) {
+    return sink(new PythonFunction(moduleName, funcName));
+  }
+
   /**
    * Apply a sink function and get a StreamSink.
    *
@@ -70,6 +105,10 @@ public class PythonDataStream extends Stream implements PythonStream {
     return new PythonStreamSink(this, new PythonOperator(func));
   }

+  public PythonKeyDataStream keyBy(String moduleName, String funcName) {
+    return keyBy(new PythonFunction(moduleName, funcName));
+  }
+
   /**
    * Apply a key-by function to this stream.
    *
@@ -77,6 +116,7 @@ public class PythonDataStream extends Stream implements PythonStream {
    * @return A new KeyDataStream.
    */
   public PythonKeyDataStream keyBy(PythonFunction func) {
+    checkPartitionCall();
     func.setFunctionInterface(FunctionInterface.KEY_FUNCTION);
     return new PythonKeyDataStream(this, new PythonOperator(func));
   }
@@ -87,8 +127,8 @@ public class PythonDataStream extends Stream implements PythonStream {
    * @return This stream.
    */
   public PythonDataStream broadcast() {
-    this.partition = PythonPartition.BroadcastPartition;
-    return this;
+    checkPartitionCall();
+    return setPartition(PythonPartition.BroadcastPartition);
   }

   /**
@@ -98,19 +138,33 @@ public class PythonDataStream extends Stream implements PythonStream {
    * @return This stream.
    */
   public PythonDataStream partitionBy(PythonPartition partition) {
-    this.partition = partition;
-    return this;
+    checkPartitionCall();
+    return setPartition(partition);
   }

   /**
-   * Set parallelism to current transformation.
-   *
-   * @param parallelism The parallelism to set.
-   * @return This stream.
+   * If the parent stream is a java stream, we can't call partition related methods
+   * on this python stream.
    */
-  public PythonDataStream setParallelism(int parallelism) {
-    this.parallelism = parallelism;
-    return this;
+  private void checkPartitionCall() {
+    if (getInputStream() != null && getInputStream().getLanguage() == Language.JAVA) {
+      throw new RuntimeException("Partition related methods can't be called on a " +
+          "python stream if parent stream is a java stream.");
+    }
+  }
+
+  /**
+   * Convert this stream to a java stream.
+   * The converted stream and this stream are the same logical stream, which has the same stream id.
+   * Changes in the converted stream will be reflected in this stream and vice versa.
+   */
+  public DataStream<Object> asJavaStream() {
+    return new DataStream<>(this);
+  }
+
+  @Override
+  public Language getLanguage() {
+    return Language.PYTHON;
   }

 }
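Note: the by-name overloads above, together with asJavaStream() here and DataStream.asPythonStream() (exercised in StreamTest below), are what make a mixed-language pipeline expressible from a java driver. A hedged sketch, assuming a java DataStream named source and placeholder python names:

    PythonDataStream mapped = source.asPythonStream()
        .map("my_mod", "my_map_func")
        .filter("my_mod", "my_filter_func");
    // Hop back to java; both views share the same stream id.
    DataStream<Object> javaView = mapped.asJavaStream();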
@@ -1,5 +1,7 @@
 package io.ray.streaming.python.stream;

+import io.ray.streaming.api.stream.DataStream;
+import io.ray.streaming.api.stream.KeyDataStream;
 import io.ray.streaming.python.PythonFunction;
 import io.ray.streaming.python.PythonFunction.FunctionInterface;
 import io.ray.streaming.python.PythonOperator;
@@ -8,11 +10,23 @@ import io.ray.streaming.python.PythonPartition;
 /**
  * Represents a python DataStream returned by a key-by operation.
  */
+@SuppressWarnings("unchecked")
 public class PythonKeyDataStream extends PythonDataStream implements PythonStream {

   public PythonKeyDataStream(PythonDataStream input, PythonOperator pythonOperator) {
-    super(input, pythonOperator);
-    this.partition = PythonPartition.KeyPartition;
+    super(input, pythonOperator, PythonPartition.KeyPartition);
+  }
+
+  /**
+   * Create a python stream that references the passed python stream.
+   * Changes in the new stream will be reflected in the referenced stream and vice versa.
+   */
+  public PythonKeyDataStream(DataStream referencedStream) {
+    super(referencedStream);
+  }
+
+  public PythonDataStream reduce(String moduleName, String funcName) {
+    return reduce(new PythonFunction(moduleName, funcName));
   }

   /**
@@ -26,9 +40,13 @@ public class PythonKeyDataStream extends PythonDataStream implements PythonStream {
     return new PythonDataStream(this, new PythonOperator(func));
   }

-  public PythonKeyDataStream setParallelism(int parallelism) {
-    this.parallelism = parallelism;
-    return this;
+  /**
+   * Convert this stream to a java stream.
+   * The converted stream and this stream are the same logical stream, which has the same stream id.
+   * Changes in the converted stream will be reflected in this stream and vice versa.
+   */
+  public KeyDataStream<Object, Object> asJavaStream() {
+    return new KeyDataStream(this);
   }

 }
@@ -1,5 +1,6 @@
 package io.ray.streaming.python.stream;

+import io.ray.streaming.api.Language;
 import io.ray.streaming.api.stream.StreamSink;
 import io.ray.streaming.python.PythonOperator;

@@ -9,12 +10,12 @@ import io.ray.streaming.python.PythonOperator;
 public class PythonStreamSink extends StreamSink implements PythonStream {
   public PythonStreamSink(PythonDataStream input, PythonOperator sinkOperator) {
     super(input, sinkOperator);
-    this.streamingContext.addSink(this);
+    getStreamingContext().addSink(this);
   }

-  public PythonStreamSink setParallelism(int parallelism) {
-    this.parallelism = parallelism;
-    return this;
+  @Override
+  public Language getLanguage() {
+    return Language.PYTHON;
   }

 }
@@ -13,13 +13,8 @@ import io.ray.streaming.python.PythonPartition;
 public class PythonStreamSource extends PythonDataStream implements StreamSource {

   private PythonStreamSource(StreamingContext streamingContext, PythonFunction sourceFunction) {
-    super(streamingContext, new PythonOperator(sourceFunction));
-    super.partition = PythonPartition.RoundRobinPartition;
-  }
-
-  public PythonStreamSource setParallelism(int parallelism) {
-    this.parallelism = parallelism;
-    return this;
+    super(streamingContext, new PythonOperator(sourceFunction),
+        PythonPartition.RoundRobinPartition);
   }

   public static PythonStreamSource from(StreamingContext streamingContext,
@@ -21,7 +21,6 @@ public class Config {
   public static final String CHANNEL_TYPE = "channel_type";
   public static final String MEMORY_CHANNEL = "memory_channel";
   public static final String NATIVE_CHANNEL = "native_channel";
-  public static final String DEFAULT_CHANNEL_TYPE = NATIVE_CHANNEL;
   public static final String CHANNEL_SIZE = "channel_size";
   public static final String CHANNEL_SIZE_DEFAULT = String.valueOf((long)Math.pow(10, 8));
   public static final String IS_RECREATE = "streaming.is_recreate";
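Note: DEFAULT_CHANNEL_TYPE is removed because DataReader and DataWriter (below) now read CHANNEL_TYPE without a fallback, so the job config is expected to carry it explicitly. A sketch, assuming the runtime populates this map elsewhere:

    Map<String, String> conf = new HashMap<>();
    conf.put(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL);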
@@ -0,0 +1,40 @@
+package io.ray.streaming.api.stream;
+
+import static org.testng.Assert.assertEquals;
+
+import io.ray.streaming.api.context.StreamingContext;
+import io.ray.streaming.operator.impl.MapOperator;
+import io.ray.streaming.python.stream.PythonDataStream;
+import io.ray.streaming.python.stream.PythonKeyDataStream;
+import org.testng.annotations.Test;
+
+@SuppressWarnings("unchecked")
+public class StreamTest {
+
+  @Test
+  public void testReferencedDataStream() {
+    DataStream dataStream = new DataStream(StreamingContext.buildContext(),
+        new MapOperator(value -> null));
+    PythonDataStream pythonDataStream = dataStream.asPythonStream();
+    DataStream javaStream = pythonDataStream.asJavaStream();
+    assertEquals(dataStream.getId(), pythonDataStream.getId());
+    assertEquals(dataStream.getId(), javaStream.getId());
+    javaStream.setParallelism(10);
+    assertEquals(dataStream.getParallelism(), pythonDataStream.getParallelism());
+    assertEquals(dataStream.getParallelism(), javaStream.getParallelism());
+  }
+
+  @Test
+  public void testReferencedKeyDataStream() {
+    DataStream dataStream = new DataStream(StreamingContext.buildContext(),
+        new MapOperator(value -> null));
+    KeyDataStream keyDataStream = dataStream.keyBy(value -> null);
+    PythonKeyDataStream pythonKeyDataStream = keyDataStream.asPythonStream();
+    KeyDataStream javaKeyDataStream = pythonKeyDataStream.asJavaStream();
+    assertEquals(keyDataStream.getId(), pythonKeyDataStream.getId());
+    assertEquals(keyDataStream.getId(), javaKeyDataStream.getId());
+    javaKeyDataStream.setParallelism(10);
+    assertEquals(keyDataStream.getParallelism(), pythonKeyDataStream.getParallelism());
+    assertEquals(keyDataStream.getParallelism(), javaKeyDataStream.getParallelism());
+  }
+}
@@ -38,7 +38,7 @@ public class JobGraphBuilderTest {

   public JobGraph buildDataSyncJobGraph() {
     StreamingContext streamingContext = StreamingContext.buildContext();
-    DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext,
+    DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
         Lists.newArrayList("a", "b", "c"));
     StreamSink streamSink = dataStream.sink(x -> LOG.info(x));
     JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink));
@@ -73,7 +73,7 @@ public class JobGraphBuilderTest {

   public JobGraph buildKeyByJobGraph() {
     StreamingContext streamingContext = StreamingContext.buildContext();
-    DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext,
+    DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
         Lists.newArrayList("1", "2", "3", "4"));
     StreamSink streamSink = dataStream.keyBy(x -> x)
         .sink(x -> LOG.info(x));
@@ -36,6 +36,11 @@
       <artifactId>flatbuffers-java</artifactId>
       <version>1.9.0.1</version>
     </dependency>
+    <dependency>
+      <groupId>com.google.code.findbugs</groupId>
+      <artifactId>jsr305</artifactId>
+      <version>3.0.2</version>
+    </dependency>
     <dependency>
       <groupId>com.google.guava</groupId>
       <artifactId>guava</artifactId>
@@ -56,6 +61,11 @@
       <artifactId>owner</artifactId>
       <version>1.0.10</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.commons</groupId>
+      <artifactId>commons-lang3</artifactId>
+      <version>3.4</version>
+    </dependency>
     <dependency>
       <groupId>org.mockito</groupId>
       <artifactId>mockito-all</artifactId>
@@ -71,11 +81,6 @@
       <artifactId>powermock-api-mockito</artifactId>
       <version>1.6.6</version>
     </dependency>
-    <dependency>
-      <groupId>org.powermock</groupId>
-      <artifactId>powermock-core</artifactId>
-      <version>1.6.6</version>
-    </dependency>
     <dependency>
       <groupId>org.powermock</groupId>
       <artifactId>powermock-module-testng</artifactId>
@@ -1,9 +1,14 @@
 package io.ray.streaming.runtime.core.collector;

-import io.ray.runtime.serializer.Serializer;
+import io.ray.api.BaseActor;
+import io.ray.api.RayPyActor;
+import io.ray.streaming.api.Language;
 import io.ray.streaming.api.collector.Collector;
 import io.ray.streaming.api.partition.Partition;
 import io.ray.streaming.message.Record;
+import io.ray.streaming.runtime.serialization.CrossLangSerializer;
+import io.ray.streaming.runtime.serialization.JavaSerializer;
+import io.ray.streaming.runtime.serialization.Serializer;
 import io.ray.streaming.runtime.transfer.ChannelID;
 import io.ray.streaming.runtime.transfer.DataWriter;
 import java.nio.ByteBuffer;
@@ -14,15 +19,24 @@ import org.slf4j.LoggerFactory;
 public class OutputCollector implements Collector<Record> {
   private static final Logger LOGGER = LoggerFactory.getLogger(OutputCollector.class);

-  private Partition partition;
-  private DataWriter writer;
-  private ChannelID[] outputQueues;
+  private final DataWriter writer;
+  private final ChannelID[] outputQueues;
+  private final Collection<BaseActor> targetActors;
+  private final Language[] targetLanguages;
+  private final Partition partition;
+  private final Serializer javaSerializer = new JavaSerializer();
+  private final Serializer crossLangSerializer = new CrossLangSerializer();

-  public OutputCollector(Collection<String> outputQueueIds,
-                         DataWriter writer,
+  public OutputCollector(DataWriter writer,
+                         Collection<String> outputQueueIds,
+                         Collection<BaseActor> targetActors,
                          Partition partition) {
-    this.outputQueues = outputQueueIds.stream().map(ChannelID::from).toArray(ChannelID[]::new);
     this.writer = writer;
+    this.outputQueues = outputQueueIds.stream().map(ChannelID::from).toArray(ChannelID[]::new);
+    this.targetActors = targetActors;
+    this.targetLanguages = targetActors.stream()
+        .map(actor -> actor instanceof RayPyActor ? Language.PYTHON : Language.JAVA)
+        .toArray(Language[]::new);
     this.partition = partition;
     LOGGER.debug("OutputCollector constructed, outputQueueIds:{}, partition:{}.",
         outputQueueIds, this.partition);
@@ -31,9 +45,32 @@ public class OutputCollector implements Collector<Record> {
   @Override
   public void collect(Record record) {
     int[] partitions = this.partition.partition(record, outputQueues.length);
-    ByteBuffer msgBuffer = ByteBuffer.wrap(Serializer.encode(record).getLeft());
+    ByteBuffer javaBuffer = null;
+    ByteBuffer crossLangBuffer = null;
     for (int partition : partitions) {
-      writer.write(outputQueues[partition], msgBuffer);
+      if (targetLanguages[partition] == Language.JAVA) {
+        // avoid repeated serialization
+        if (javaBuffer == null) {
+          byte[] bytes = javaSerializer.serialize(record);
+          javaBuffer = ByteBuffer.allocate(1 + bytes.length);
+          javaBuffer.put(Serializer.JAVA_TYPE_ID);
+          // TODO(chaokunyang) remove copy
+          javaBuffer.put(bytes);
+          javaBuffer.flip();
+        }
+        writer.write(outputQueues[partition], javaBuffer.duplicate());
+      } else {
+        // avoid repeated serialization
+        if (crossLangBuffer == null) {
+          byte[] bytes = crossLangSerializer.serialize(record);
+          crossLangBuffer = ByteBuffer.allocate(1 + bytes.length);
+          crossLangBuffer.put(Serializer.CROSS_LANG_TYPE_ID);
+          // TODO(chaokunyang) remove copy
+          crossLangBuffer.put(bytes);
+          crossLangBuffer.flip();
+        }
+        writer.write(outputQueues[partition], crossLangBuffer.duplicate());
+      }
     }
   }
 }
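Note on the wire format above: every message is framed as one type-id byte (Serializer.JAVA_TYPE_ID or CROSS_LANG_TYPE_ID) followed by the payload, and each buffer is built once per target language, then re-sent with duplicate(), which gives an independent read cursor over the same underlying bytes. A minimal sketch of that framing outside the collector; `record` stands in for any Record:

    byte[] payload = new JavaSerializer().serialize(record);
    ByteBuffer framed = ByteBuffer.allocate(1 + payload.length);
    framed.put(Serializer.JAVA_TYPE_ID);   // 1-byte frame header
    framed.put(payload);                   // payload copy (see upstream TODO)
    framed.flip();                         // ready for reading
    ByteBuffer resend = framed.duplicate();  // own position/limit, shared bytes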
@@ -12,6 +12,7 @@ import io.ray.streaming.runtime.core.graph.ExecutionNode;
 import io.ray.streaming.runtime.core.graph.ExecutionTask;
 import io.ray.streaming.runtime.generated.RemoteCall;
 import io.ray.streaming.runtime.generated.Streaming;
+import io.ray.streaming.runtime.serialization.MsgPackSerializer;
 import java.util.Arrays;

 public class GraphPbBuilder {
@@ -74,11 +75,10 @@ public class GraphPbBuilder {
   private byte[] serializeFunction(Function function) {
     if (function instanceof PythonFunction) {
       PythonFunction pyFunc = (PythonFunction) function;
-      // function_bytes, module_name, class_name, function_name, function_interface
+      // function_bytes, module_name, function_name, function_interface
       return serializer.serialize(Arrays.asList(
           pyFunc.getFunction(), pyFunc.getModuleName(),
-          pyFunc.getClassName(), pyFunc.getFunctionName(),
-          pyFunc.getFunctionInterface()
+          pyFunc.getFunctionName(), pyFunc.getFunctionInterface()
       ));
     } else {
       return new byte[0];
@@ -88,10 +88,10 @@ public class GraphPbBuilder {
   private byte[] serializePartition(Partition partition) {
     if (partition instanceof PythonPartition) {
       PythonPartition pythonPartition = (PythonPartition) partition;
-      // partition_bytes, module_name, class_name, function_name
+      // partition_bytes, module_name, function_name
       return serializer.serialize(Arrays.asList(
           pythonPartition.getPartition(), pythonPartition.getModuleName(),
-          pythonPartition.getClassName(), pythonPartition.getFunctionName()
+          pythonPartition.getFunctionName()
       ));
     } else {
       return new byte[0];
@@ -1,16 +1,21 @@
 package io.ray.streaming.runtime.python;

+import com.google.common.base.Preconditions;
+import com.google.common.primitives.Primitives;
 import io.ray.streaming.api.context.StreamingContext;
 import io.ray.streaming.python.PythonFunction;
 import io.ray.streaming.python.PythonPartition;
 import io.ray.streaming.python.stream.PythonStreamSource;
+import io.ray.streaming.runtime.serialization.MsgPackSerializer;
 import io.ray.streaming.runtime.util.ReflectionUtils;
 import java.lang.reflect.Method;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
+import java.util.function.Function;
 import java.util.stream.Collectors;
-import org.msgpack.core.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -68,7 +73,7 @@ public class PythonGateway {
     Preconditions.checkNotNull(streamingContext);
     try {
       PythonStreamSource pythonStreamSource = PythonStreamSource.from(
-          streamingContext, PythonFunction.fromFunction(pySourceFunc));
+          streamingContext, new PythonFunction(pySourceFunc));
       referenceMap.put(getReferenceId(pythonStreamSource), pythonStreamSource);
       return serializer.serialize(getReferenceId(pythonStreamSource));
     } catch (Exception e) {
@@ -84,7 +89,7 @@ public class PythonGateway {
   }

   public byte[] createPyFunc(byte[] pyFunc) {
-    PythonFunction function = PythonFunction.fromFunction(pyFunc);
+    PythonFunction function = new PythonFunction(pyFunc);
     referenceMap.put(getReferenceId(function), function);
     return serializer.serialize(getReferenceId(function));
   }
@@ -98,15 +103,21 @@ public class PythonGateway {
   public byte[] callFunction(byte[] paramsBytes) {
     try {
       List<Object> params = (List<Object>) serializer.deserialize(paramsBytes);
-      params = processReferenceParameters(params);
+      params = processParameters(params);
       LOG.info("callFunction params {}", params);
       String className = (String) params.get(0);
       String funcName = (String) params.get(1);
       Class<?> clz = Class.forName(className, true, this.getClass().getClassLoader());
-      Method method = ReflectionUtils.findMethod(clz, funcName);
+      Class[] paramsTypes = params.subList(2, params.size()).stream()
+          .map(Object::getClass).toArray(Class[]::new);
+      Method method = findMethod(clz, funcName, paramsTypes);
       Object result = method.invoke(null, params.subList(2, params.size()).toArray());
+      if (returnReference(result)) {
         referenceMap.put(getReferenceId(result), result);
         return serializer.serialize(getReferenceId(result));
+      } else {
+        return serializer.serialize(result);
+      }
     } catch (Exception e) {
       throw new RuntimeException(e);
     }
@@ -115,31 +126,78 @@ public class PythonGateway {
   public byte[] callMethod(byte[] paramsBytes) {
     try {
       List<Object> params = (List<Object>) serializer.deserialize(paramsBytes);
-      params = processReferenceParameters(params);
+      params = processParameters(params);
       LOG.info("callMethod params {}", params);
       Object obj = params.get(0);
       String methodName = (String) params.get(1);
-      Method method = ReflectionUtils.findMethod(obj.getClass(), methodName);
+      Class<?> clz = obj.getClass();
+      Class[] paramsTypes = params.subList(2, params.size()).stream()
+          .map(Object::getClass).toArray(Class[]::new);
+      Method method = findMethod(clz, methodName, paramsTypes);
       Object result = method.invoke(obj, params.subList(2, params.size()).toArray());
+      if (returnReference(result)) {
         referenceMap.put(getReferenceId(result), result);
         return serializer.serialize(getReferenceId(result));
+      } else {
+        return serializer.serialize(result);
+      }
     } catch (Exception e) {
       throw new RuntimeException(e);
     }
   }

-  private List<Object> processReferenceParameters(List<Object> params) {
-    return params.stream().map(this::processReferenceParameter)
+  private static Method findMethod(Class<?> cls, String methodName, Class[] paramsTypes) {
+    List<Method> methods = ReflectionUtils.findMethods(cls, methodName);
+    if (methods.size() == 1) {
+      return methods.get(0);
+    }
+    // Convert all params types to primitive types if it's boxed type
+    Class[] unwrappedTypes = Arrays.stream(paramsTypes)
+        .map((Function<Class, Class>) Primitives::unwrap)
+        .toArray(Class[]::new);
+    Optional<Method> any = methods.stream()
+        .filter(m -> Arrays.equals(m.getParameterTypes(), paramsTypes) ||
+            Arrays.equals(m.getParameterTypes(), unwrappedTypes))
+        .findAny();
+    Preconditions.checkArgument(any.isPresent(),
+        String.format("Method %s with type %s doesn't exist on class %s",
+            methodName, Arrays.toString(paramsTypes), cls));
+    return any.get();
+  }
+
+  private static boolean returnReference(Object value) {
+    return !(value instanceof Number) && !(value instanceof String) && !(value instanceof byte[]);
+  }
+
+  public byte[] newInstance(byte[] classNameBytes) {
+    String className = (String) serializer.deserialize(classNameBytes);
+    try {
+      Class<?> clz = Class.forName(className, true, this.getClass().getClassLoader());
+      Object instance = clz.newInstance();
+      referenceMap.put(getReferenceId(instance), instance);
+      return serializer.serialize(getReferenceId(instance));
+    } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) {
+      throw new IllegalArgumentException(
+          String.format("Create instance for class %s failed", className), e);
+    }
+  }
+
+  private List<Object> processParameters(List<Object> params) {
+    return params.stream().map(this::processParameter)
         .collect(Collectors.toList());
   }

-  private Object processReferenceParameter(Object o) {
+  private Object processParameter(Object o) {
     if (o instanceof String) {
       Object value = referenceMap.get(o);
       if (value != null) {
         return value;
       }
     }
+    // Since python can't represent byte/short, we convert all Byte/Short to Integer
+    if (o instanceof Byte || o instanceof Short) {
+      return ((Number) o).intValue();
+    }
     return o;
   }
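Note: the overload resolution in findMethod needs the Primitives.unwrap step because msgpack always hands the gateway boxed values, while many java targets declare primitives; Integer.class does not equal int.class, so a plain getParameterTypes() comparison would miss a method like setParallelism(int). A quick illustration using Guava:

    Class<?> boxed = Integer.class;
    Class<?> primitive = Primitives.unwrap(boxed);  // yields int.class
    // Matching against {Integer.class} fails for setParallelism(int);
    // matching against the unwrapped {int.class} succeeds.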
@@ -41,15 +41,11 @@ public class JobSchedulerImpl implements JobScheduler {
   public void schedule(JobGraph jobGraph, Map<String, String> jobConfig) {
     this.jobConfig = jobConfig;
     this.jobGraph = jobGraph;
-    if (Ray.internal() == null) {
-      System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
-      Ray.init();
-    }

     ExecutionGraph executionGraph = this.taskAssigner.assign(this.jobGraph);
     List<ExecutionNode> executionNodes = executionGraph.getExecutionNodeList();
     boolean hasPythonNode = executionNodes.stream()
-        .allMatch(node -> node.getLanguage() == Language.PYTHON);
+        .anyMatch(node -> node.getLanguage() == Language.PYTHON);
     RemoteCall.ExecutionGraph executionGraphPb = null;
     if (hasPythonNode) {
       executionGraphPb = new GraphPbBuilder().buildExecutionGraphPb(executionGraph);
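Note: allMatch to anyMatch is a behavior fix, not a rename. The protobuf execution graph must be built whenever at least one python node exists, and allMatch only fired for all-python jobs, silently skipping hybrid ones:

    // allMatch: true only when every node is python, so hybrid jobs are missed.
    // anyMatch: true as soon as a single python node appears.
    List<Language> langs = Arrays.asList(Language.JAVA, Language.PYTHON);
    langs.stream().allMatch(l -> l == Language.PYTHON);  // false: graph pb skipped
    langs.stream().anyMatch(l -> l == Language.PYTHON);  // true: graph pb built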
@@ -2,6 +2,8 @@ package io.ray.streaming.runtime.schedule;

 import io.ray.api.BaseActor;
 import io.ray.api.Ray;
+import io.ray.api.RayActor;
+import io.ray.api.RayPyActor;
 import io.ray.api.function.PyActorClass;
 import io.ray.streaming.jobgraph.JobEdge;
 import io.ray.streaming.jobgraph.JobGraph;
@@ -15,8 +17,11 @@ import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

 public class TaskAssignerImpl implements TaskAssigner {
+  private static final Logger LOG = LoggerFactory.getLogger(TaskAssignerImpl.class);

   /**
    * Assign an optimized logical plan to execution graph.
@@ -61,11 +66,17 @@ public class TaskAssignerImpl implements TaskAssigner {

   private BaseActor createWorker(JobVertex jobVertex) {
     switch (jobVertex.getLanguage()) {
-      case PYTHON:
-        return Ray.createActor(
+      case PYTHON: {
+        RayPyActor worker = Ray.createActor(
             new PyActorClass("ray.streaming.runtime.worker", "JobWorker"));
-      case JAVA:
-        return Ray.createActor(JobWorker::new);
+        LOG.info("Created python worker {}", worker);
+        return worker;
+      }
+      case JAVA: {
+        RayActor<JobWorker> worker = Ray.createActor(JobWorker::new);
+        LOG.info("Created java worker {}", worker);
+        return worker;
+      }
       default:
         throw new UnsupportedOperationException(
             "Unsupported language " + jobVertex.getLanguage());
@@ -0,0 +1,62 @@
+package io.ray.streaming.runtime.serialization;
+
+import io.ray.streaming.message.KeyRecord;
+import io.ray.streaming.message.Record;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * A serializer for cross-lang serialization between java/python.
+ * TODO implements a more sophisticated serialization framework
+ */
+public class CrossLangSerializer implements Serializer {
+  private static final byte RECORD_TYPE_ID = 0;
+  private static final byte KEY_RECORD_TYPE_ID = 1;
+
+  private MsgPackSerializer msgPackSerializer = new MsgPackSerializer();
+
+  public byte[] serialize(Object object) {
+    Record record = (Record) object;
+    Object value = record.getValue();
+    Class<? extends Record> clz = record.getClass();
+    if (clz == Record.class) {
+      return msgPackSerializer.serialize(Arrays.asList(
+          RECORD_TYPE_ID, record.getStream(), value));
+    } else if (clz == KeyRecord.class) {
+      KeyRecord keyRecord = (KeyRecord) record;
+      Object key = keyRecord.getKey();
+      return msgPackSerializer.serialize(Arrays.asList(
+          KEY_RECORD_TYPE_ID, keyRecord.getStream(), key, value));
+    } else {
+      throw new UnsupportedOperationException(
+          String.format("Serialize %s is unsupported.", record));
+    }
+  }
+
+  @SuppressWarnings("unchecked")
+  public Record deserialize(byte[] bytes) {
+    List list = (List) msgPackSerializer.deserialize(bytes);
+    Byte typeId = (Byte) list.get(0);
+    switch (typeId) {
+      case RECORD_TYPE_ID: {
+        String stream = (String) list.get(1);
+        Object value = list.get(2);
+        Record record = new Record(value);
+        record.setStream(stream);
+        return record;
+      }
+      case KEY_RECORD_TYPE_ID: {
+        String stream = (String) list.get(1);
+        Object key = list.get(2);
+        Object value = list.get(3);
+        KeyRecord keyRecord = new KeyRecord(key, value);
+        keyRecord.setStream(stream);
+        return keyRecord;
+      }
+      default:
+        throw new UnsupportedOperationException("Unsupported type " + typeId);
+    }
+  }
+}
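Note: the cross-lang payload is a plain msgpack list, [type_id, stream, value] for Record and [type_id, stream, key, value] for KeyRecord, so the python peer can decode it with any msgpack library and no java-specific framing. A round-trip sketch:

    CrossLangSerializer serializer = new CrossLangSerializer();
    Record record = new Record("hello");
    record.setStream("stream1");
    Record copy = serializer.deserialize(serializer.serialize(record));
    // copy carries the same stream ("stream1") and value ("hello").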
@@ -0,0 +1,15 @@
+package io.ray.streaming.runtime.serialization;
+
+import io.ray.runtime.serializer.FstSerializer;
+
+public class JavaSerializer implements Serializer {
+  @Override
+  public byte[] serialize(Object object) {
+    return FstSerializer.encode(object);
+  }
+
+  @Override
+  public <T> T deserialize(byte[] bytes) {
+    return FstSerializer.decode(bytes);
+  }
+}
@@ -1,4 +1,4 @@
-package io.ray.streaming.runtime.python;
+package io.ray.streaming.runtime.serialization;

 import com.google.common.io.BaseEncoding;
 import java.util.ArrayList;
@@ -31,6 +31,10 @@ public class MsgPackSerializer {
     Class<?> clz = obj.getClass();
     if (clz == Boolean.class) {
       packer.packBoolean((Boolean) obj);
+    } else if (clz == Byte.class) {
+      packer.packByte((Byte) obj);
+    } else if (clz == Short.class) {
+      packer.packShort((Short) obj);
     } else if (clz == Integer.class) {
       packer.packInt((Integer) obj);
     } else if (clz == Long.class) {
@@ -84,7 +88,11 @@ public class MsgPackSerializer {
         return value.asBooleanValue().getBoolean();
       case INTEGER:
         IntegerValue iv = value.asIntegerValue();
-        if (iv.isInIntRange()) {
+        if (iv.isInByteRange()) {
+          return iv.toByte();
+        } else if (iv.isInShortRange()) {
+          return iv.toShort();
+        } else if (iv.isInIntRange()) {
           return iv.toInt();
         } else if (iv.isInLongRange()) {
           return iv.toLong();
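Note: deserialization now narrows every integer to the smallest java type that holds it (byte, then short, then int, then long); PythonGateway.processParameter widens Byte/Short back to Integer before reflective dispatch, since python has no fixed-width small ints. A quick illustration of the narrowing:

    MsgPackSerializer serializer = new MsgPackSerializer();
    Object out = serializer.deserialize(serializer.serialize(1000));
    // 1000 overflows byte range but fits a short, so an Integer goes in
    // and a Short comes out.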
@@ -0,0 +1,12 @@
+package io.ray.streaming.runtime.serialization;
+
+public interface Serializer {
+  byte CROSS_LANG_TYPE_ID = 0;
+  byte JAVA_TYPE_ID = 1;
+  byte PYTHON_TYPE_ID = 2;
+
+  byte[] serialize(Object object);
+
+  <T> T deserialize(byte[] bytes);
+
+}
@@ -20,7 +20,7 @@ import java.util.Map;
  */
 public class ChannelCreationParametersBuilder {

-  public class Parameter {
+  public static class Parameter {

     private ActorId actorId;
     private FunctionDescriptor asyncFunctionDescriptor;
@@ -138,7 +138,7 @@ public class ChannelCreationParametersBuilder {
       parameter.setAsyncFunctionDescriptor(pyAsyncFunctionDesc);
       parameter.setSyncFunctionDescriptor(pySyncFunctionDesc);
     } else {
-      Preconditions.checkArgument(false, "Invalid actor type");
+      throw new IllegalArgumentException("Invalid actor type");
     }
     parameters.add(parameter);
   }
@@ -152,10 +152,10 @@ public class ChannelCreationParametersBuilder {
   }

   public String toString() {
-    String str = "";
+    StringBuilder str = new StringBuilder();
     for (Parameter param : parameters) {
-      str += param.toString();
+      str.append(param.toString());
     }
-    return str;
+    return str.toString();
   }
 }
@@ -40,7 +40,7 @@ public class DataReader {
     }
     long timerInterval = Long.parseLong(
         conf.getOrDefault(Config.TIMER_INTERVAL_MS, "-1"));
-    String channelType = conf.getOrDefault(Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
+    String channelType = conf.get(Config.CHANNEL_TYPE);
     boolean isMock = false;
     if (Config.MEMORY_CHANNEL.equals(channelType)) {
       isMock = true;
@@ -37,7 +37,7 @@ public class DataWriter {
                     Map<String, String> conf) {
     Preconditions.checkArgument(!outputChannels.isEmpty());
     Preconditions.checkArgument(outputChannels.size() == toActors.size());
-    ChannelCreationParametersBuilder initialParameters =
+    ChannelCreationParametersBuilder initParameters =
         new ChannelCreationParametersBuilder().buildOutputQueueParameters(outputChannels, toActors);
     byte[][] outputChannelsBytes = outputChannels.stream()
         .map(ChannelID::idStrToBytes).toArray(byte[][]::new);
@@ -47,13 +47,14 @@ public class DataWriter {
     for (int i = 0; i < outputChannels.size(); i++) {
       msgIds[i] = 0;
     }
-    String channelType = conf.getOrDefault(Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
+    String channelType = conf.get(Config.CHANNEL_TYPE);
     boolean isMock = false;
-    if (Config.MEMORY_CHANNEL.equals(channelType)) {
+    if (Config.MEMORY_CHANNEL.equalsIgnoreCase(channelType)) {
       isMock = true;
+      LOGGER.info("Using memory channel");
     }
     this.nativeWriterPtr = createWriterNative(
-        initialParameters,
+        initParameters,
         outputChannelsBytes,
         msgIds,
         channelSize,
@@ -19,6 +19,7 @@ public class ReflectionUtils {

   /**
    * For covariant return type, return the most specific method.
+   *
    * @return all methods named by {@code methodName},
    */
   public static List<Method> findMethods(Class<?> cls, String methodName) {
@@ -1,5 +1,6 @@
 package io.ray.streaming.runtime.worker;

+import io.ray.api.Ray;
 import io.ray.streaming.runtime.core.graph.ExecutionGraph;
 import io.ray.streaming.runtime.core.graph.ExecutionNode;
 import io.ray.streaming.runtime.core.graph.ExecutionNode.NodeType;
@@ -14,11 +15,8 @@ import io.ray.streaming.runtime.worker.context.WorkerContext;
 import io.ray.streaming.runtime.worker.tasks.OneInputStreamTask;
 import io.ray.streaming.runtime.worker.tasks.SourceStreamTask;
 import io.ray.streaming.runtime.worker.tasks.StreamTask;
-import io.ray.streaming.util.Config;
-
 import java.io.Serializable;
 import java.util.Map;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -27,6 +25,8 @@ import org.slf4j.LoggerFactory;
  */
 public class JobWorker implements Serializable {
   private static final Logger LOGGER = LoggerFactory.getLogger(JobWorker.class);
+  // special flag to indicate this actor not ready
+  private static final byte[] NOT_READY_FLAG = new byte[4];

   static {
     EnvUtil.loadNativeLibraries();
@@ -54,11 +54,10 @@ public class JobWorker implements Serializable {
     this.nodeType = executionNode.getNodeType();
     this.streamProcessor = ProcessBuilder
         .buildProcessor(executionNode.getStreamOperator());
-    LOGGER.debug("Initializing StreamWorker, taskId: {}, operator: {}.", taskId, streamProcessor);
+    LOGGER.info("Initializing StreamWorker, pid {}, taskId: {}, operator: {}.",
+        EnvUtil.getJvmPid(), taskId, streamProcessor);

-    String channelType = (String) this.config.getOrDefault(
-        Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
-    if (channelType.equals(Config.NATIVE_CHANNEL)) {
+    if (!Ray.getRuntimeContext().isSingleProcess()) {
       transferHandler = new TransferHandler();
     }
     task = createStreamTask();
@@ -124,6 +123,9 @@ public class JobWorker implements Serializable {
    * and receive result from this actor
    */
  public byte[] onReaderMessageSync(byte[] buffer) {
+    if (transferHandler == null) {
+      return NOT_READY_FLAG;
+    }
     return transferHandler.onReaderMessageSync(buffer);
   }

@@ -139,6 +141,9 @@ public class JobWorker implements Serializable {
    * and receive result from this actor
    */
   public byte[] onWriterMessageSync(byte[] buffer) {
+    if (transferHandler == null) {
+      return NOT_READY_FLAG;
+    }
     return transferHandler.onWriterMessageSync(buffer);
   }
 }
@ -1,7 +1,9 @@
|
||||||
package io.ray.streaming.runtime.worker.tasks;
|
package io.ray.streaming.runtime.worker.tasks;
|
||||||
|
|
||||||
import io.ray.runtime.serializer.Serializer;
|
|
||||||
import io.ray.streaming.runtime.core.processor.Processor;
|
import io.ray.streaming.runtime.core.processor.Processor;
|
||||||
|
import io.ray.streaming.runtime.serialization.CrossLangSerializer;
|
||||||
|
import io.ray.streaming.runtime.serialization.JavaSerializer;
|
||||||
|
import io.ray.streaming.runtime.serialization.Serializer;
|
||||||
import io.ray.streaming.runtime.transfer.Message;
|
import io.ray.streaming.runtime.transfer.Message;
|
||||||
import io.ray.streaming.runtime.worker.JobWorker;
|
import io.ray.streaming.runtime.worker.JobWorker;
|
||||||
import io.ray.streaming.util.Config;
|
import io.ray.streaming.util.Config;
|
||||||
|
@ -10,11 +12,15 @@ public abstract class InputStreamTask extends StreamTask {
|
||||||
private volatile boolean running = true;
|
private volatile boolean running = true;
|
||||||
private volatile boolean stopped = false;
|
private volatile boolean stopped = false;
|
||||||
private long readTimeoutMillis;
|
private long readTimeoutMillis;
|
||||||
|
private final io.ray.streaming.runtime.serialization.Serializer javaSerializer;
|
||||||
|
private final io.ray.streaming.runtime.serialization.Serializer crossLangSerializer;
|
||||||
|
|
||||||
public InputStreamTask(int taskId, Processor processor, JobWorker streamWorker) {
|
public InputStreamTask(int taskId, Processor processor, JobWorker streamWorker) {
|
||||||
super(taskId, processor, streamWorker);
|
super(taskId, processor, streamWorker);
|
||||||
readTimeoutMillis = Long.parseLong((String) streamWorker.getConfig()
|
readTimeoutMillis = Long.parseLong((String) streamWorker.getConfig()
|
||||||
.getOrDefault(Config.READ_TIMEOUT_MS, Config.DEFAULT_READ_TIMEOUT_MS));
|
.getOrDefault(Config.READ_TIMEOUT_MS, Config.DEFAULT_READ_TIMEOUT_MS));
|
||||||
|
javaSerializer = new JavaSerializer();
|
||||||
|
crossLangSerializer = new CrossLangSerializer();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -26,9 +32,15 @@ public abstract class InputStreamTask extends StreamTask {
|
||||||
while (running) {
|
while (running) {
|
||||||
Message item = reader.read(readTimeoutMillis);
|
Message item = reader.read(readTimeoutMillis);
|
||||||
if (item != null) {
|
if (item != null) {
|
||||||
byte[] bytes = new byte[item.body().remaining()];
|
byte[] bytes = new byte[item.body().remaining() - 1];
|
||||||
|
byte typeId = item.body().get();
|
||||||
item.body().get(bytes);
|
item.body().get(bytes);
|
||||||
Object obj = Serializer.decode(bytes, Object.class);
|
Object obj;
|
||||||
|
if (typeId == Serializer.JAVA_TYPE_ID) {
|
||||||
|
obj = javaSerializer.deserialize(bytes);
|
||||||
|
} else {
|
||||||
|
obj = crossLangSerializer.deserialize(bytes);
|
||||||
|
}
|
||||||
processor.process(obj);
|
processor.process(obj);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,6 @@ import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
public abstract class StreamTask implements Runnable {
|
public abstract class StreamTask implements Runnable {
|
||||||
|
|
||||||
private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class);
|
private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class);
|
||||||
|
|
||||||
protected int taskId;
|
protected int taskId;
|
||||||
|
@ -53,8 +52,8 @@ public abstract class StreamTask implements Runnable {
|
||||||
String queueSize = worker.getConfig()
|
String queueSize = worker.getConfig()
|
||||||
.getOrDefault(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT);
|
.getOrDefault(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT);
|
||||||
queueConf.put(Config.CHANNEL_SIZE, queueSize);
|
queueConf.put(Config.CHANNEL_SIZE, queueSize);
|
||||||
String channelType = worker.getConfig()
|
String channelType = Ray.getRuntimeContext().isSingleProcess() ?
|
||||||
.getOrDefault(Config.CHANNEL_TYPE, Config.MEMORY_CHANNEL);
|
Config.MEMORY_CHANNEL : Config.NATIVE_CHANNEL;
|
||||||
queueConf.put(Config.CHANNEL_TYPE, channelType);
|
queueConf.put(Config.CHANNEL_TYPE, channelType);
|
||||||
|
|
||||||
ExecutionGraph executionGraph = worker.getExecutionGraph();
|
ExecutionGraph executionGraph = worker.getExecutionGraph();
|
||||||
|
@ -82,7 +81,7 @@ public abstract class StreamTask implements Runnable {
|
||||||
LOG.info("Create DataWriter succeed.");
|
LOG.info("Create DataWriter succeed.");
|
||||||
writers.put(edge, writer);
|
writers.put(edge, writer);
|
||||||
Partition partition = edge.getPartition();
|
Partition partition = edge.getPartition();
|
||||||
collectors.add(new OutputCollector(channelIDs, writer, partition));
|
collectors.add(new OutputCollector(writer, channelIDs, outputActors.values(), partition));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,8 +105,8 @@ public abstract class StreamTask implements Runnable {
|
||||||
reader = new DataReader(channelIDs, inputActors, queueConf);
|
reader = new DataReader(channelIDs, inputActors, queueConf);
|
||||||
}
|
}
|
||||||
|
|
||||||
RuntimeContext runtimeContext = new RayRuntimeContext(worker.getExecutionTask(),
|
RuntimeContext runtimeContext = new RayRuntimeContext(
|
||||||
worker.getConfig(), executionNode.getParallelism());
|
worker.getExecutionTask(), worker.getConfig(), executionNode.getParallelism());
|
||||||
|
|
||||||
processor.open(collectors, runtimeContext);
|
processor.open(collectors, runtimeContext);
|
||||||
|
|
||||||
|
|
|
@ -24,11 +24,13 @@ public abstract class BaseUnitTest {
|
||||||
|
|
||||||
@BeforeMethod
|
@BeforeMethod
|
||||||
public void testBegin(Method method) {
|
public void testBegin(Method method) {
|
||||||
LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: " + method.getName() + " began >>>>>>>>>>>>>>>>>>>>");
|
LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: {}.{} began >>>>>>>>>>>>>>>>>>>>",
|
||||||
|
method.getDeclaringClass(), method.getName());
|
||||||
}
|
}
|
||||||
|
|
||||||
@AfterMethod
|
@AfterMethod
|
||||||
public void testEnd(Method method) {
|
public void testEnd(Method method) {
|
||||||
LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: " + method.getName() + " end >>>>>>>>>>>>>>>>>>");
|
LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: {}.{} end >>>>>>>>>>>>>>>>>>>>",
|
||||||
|
method.getDeclaringClass(), method.getName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -80,7 +80,7 @@ public class ExecutionGraphTest extends BaseUnitTest {
|
||||||
|
|
||||||
public static JobGraph buildJobGraph() {
|
public static JobGraph buildJobGraph() {
|
||||||
StreamingContext streamingContext = StreamingContext.buildContext();
|
StreamingContext streamingContext = StreamingContext.buildContext();
|
||||||
DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext,
|
DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
|
||||||
Lists.newArrayList("a", "b", "c"));
|
Lists.newArrayList("a", "b", "c"));
|
||||||
StreamSink streamSink = dataStream.sink(x -> LOG.info(x));
|
StreamSink streamSink = dataStream.sink(x -> LOG.info(x));
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
package io.ray.streaming.runtime.demo;
|
||||||
|
|
||||||
|
import io.ray.api.Ray;
|
||||||
|
import io.ray.streaming.api.context.StreamingContext;
|
||||||
|
import io.ray.streaming.api.function.impl.FilterFunction;
|
||||||
|
import io.ray.streaming.api.function.impl.MapFunction;
|
||||||
|
import io.ray.streaming.api.stream.DataStreamSource;
|
||||||
|
import io.ray.streaming.runtime.BaseUnitTest;
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.testng.annotations.Test;
|
||||||
|
|
||||||
|
public class HybridStreamTest extends BaseUnitTest implements Serializable {
|
||||||
|
private static final Logger LOG = LoggerFactory.getLogger(HybridStreamTest.class);
|
||||||
|
|
||||||
|
public static class Mapper1 implements MapFunction<Object, Object> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object map(Object value) {
|
||||||
|
LOG.info("HybridStreamTest Mapper1 {}", value);
|
||||||
|
return value.toString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Filter1 implements FilterFunction<Object> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean filter(Object value) throws Exception {
|
||||||
|
LOG.info("HybridStreamTest Filter1 {}", value);
|
||||||
|
return !value.toString().contains("b");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testHybridDataStream() throws InterruptedException {
|
||||||
|
Ray.shutdown();
|
||||||
|
StreamingContext context = StreamingContext.buildContext();
|
||||||
|
DataStreamSource<String> streamSource =
|
||||||
|
DataStreamSource.fromCollection(context, Arrays.asList("a", "b", "c"));
|
||||||
|
streamSource
|
||||||
|
.map(x -> x + x)
|
||||||
|
.asPythonStream()
|
||||||
|
.map("ray.streaming.tests.test_hybrid_stream", "map_func1")
|
||||||
|
.filter("ray.streaming.tests.test_hybrid_stream", "filter_func1")
|
||||||
|
.asJavaStream()
|
||||||
|
.sink(x -> System.out.println("HybridStreamTest: " + x));
|
||||||
|
context.execute("HybridStreamTestJob");
|
||||||
|
TimeUnit.SECONDS.sleep(3);
|
||||||
|
context.stop();
|
||||||
|
LOG.info("HybridStreamTest succeed");
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
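Note: the Python functions this hybrid test references ("map_func1" and "filter_func1" in module ray.streaming.tests.test_hybrid_stream) are not part of this excerpt. A minimal sketch of what they might look like, assuming plain one-argument functions that the runtime wraps as Simple*Function adapters (the bodies below are illustrative, not the actual test module):

    # hypothetical contents of ray/streaming/tests/test_hybrid_stream.py
    def map_func1(x):
        # tag records so the Java sink can observe the Python hop
        return "python-" + str(x)

    def filter_func1(x):
        # mirror Filter1 on the Java side: drop records containing "b"
        return "b" not in str(x)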
[WordCountTest.java]

@@ -1,6 +1,7 @@
 package io.ray.streaming.runtime.demo;
 
 import com.google.common.collect.ImmutableMap;
+import io.ray.api.Ray;
 import io.ray.streaming.api.context.StreamingContext;
 import io.ray.streaming.api.function.impl.FlatMapFunction;
 import io.ray.streaming.api.function.impl.ReduceFunction;
@@ -29,6 +30,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable {
 
   @Test
   public void testWordCount() {
+    Ray.shutdown();
     StreamingContext streamingContext = StreamingContext.buildContext();
     Map<String, String> config = new HashMap<>();
     config.put(Config.STREAMING_BATCH_MAX_COUNT, "1");
@@ -36,7 +38,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable {
     streamingContext.withConfig(config);
     List<String> text = new ArrayList<>();
     text.add("hello world eagle eagle eagle");
-    DataStreamSource<String> streamSource = DataStreamSource.buildSource(streamingContext, text);
+    DataStreamSource<String> streamSource = DataStreamSource.fromCollection(streamingContext, text);
     streamSource
         .flatMap((FlatMapFunction<String, WordAndCount>) (value, collector) -> {
           String[] records = value.split(" ");
@@ -62,6 +64,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable {
       }
     }
     Assert.assertEquals(wordCount, ImmutableMap.of("eagle", 3, "hello", 1));
+    streamingContext.stop();
   }
 
   private static class WordAndCount implements Serializable {

[test in package io.ray.streaming.runtime.python]

@@ -3,6 +3,7 @@ package io.ray.streaming.runtime.python;
 import io.ray.streaming.api.stream.StreamSink;
 import io.ray.streaming.jobgraph.JobGraph;
 import io.ray.streaming.jobgraph.JobGraphBuilder;
+import io.ray.streaming.runtime.serialization.MsgPackSerializer;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;

[TaskAssignerImplTest.java]

@@ -57,7 +57,7 @@ public class TaskAssignerImplTest extends BaseUnitTest {
 
   public JobGraph buildDataSyncPlan() {
     StreamingContext streamingContext = StreamingContext.buildContext();
-    DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext,
+    DataStream<String> dataStream = DataStreamSource.fromCollection(streamingContext,
         Lists.newArrayList("a", "b", "c"));
     DataStreamSink streamSink = dataStream.sink(LOGGER::info);
     JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink));

[CrossLangSerializerTest.java (new file)]

@@ -0,0 +1,26 @@
+package io.ray.streaming.runtime.serialization;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+import org.apache.commons.lang3.builder.EqualsBuilder;
+import io.ray.streaming.message.KeyRecord;
+import io.ray.streaming.message.Record;
+import org.testng.annotations.Test;
+
+public class CrossLangSerializerTest {
+
+  @Test
+  @SuppressWarnings("unchecked")
+  public void testSerialize() {
+    CrossLangSerializer serializer = new CrossLangSerializer();
+    Record record = new Record("value");
+    record.setStream("stream1");
+    assertTrue(EqualsBuilder.reflectionEquals(record,
+        serializer.deserialize(serializer.serialize(record))));
+    KeyRecord keyRecord = new KeyRecord("key", "value");
+    keyRecord.setStream("stream2");
+    assertEquals(keyRecord,
+        serializer.deserialize(serializer.serialize(keyRecord)));
+  }
+}

[MsgPackSerializerTest.java (moved to io.ray.streaming.runtime.serialization)]

@@ -1,4 +1,7 @@
-package io.ray.streaming.runtime.python;
+package io.ray.streaming.runtime.serialization;
 
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -6,25 +9,37 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import org.testng.annotations.Test;
-import static org.testng.Assert.assertEquals;
-import static org.testng.Assert.assertTrue;
 
 @SuppressWarnings("unchecked")
 public class MsgPackSerializerTest {
 
+  @Test
+  public void testSerializeByte() {
+    MsgPackSerializer serializer = new MsgPackSerializer();
+
+    assertEquals(serializer.deserialize(
+        serializer.serialize((byte) 1)), (byte) 1);
+  }
+
   @Test
   public void testSerialize() {
     MsgPackSerializer serializer = new MsgPackSerializer();
 
+    assertEquals(serializer.deserialize
+        (serializer.serialize(Short.MAX_VALUE)), Short.MAX_VALUE);
+    assertEquals(serializer.deserialize(
+        serializer.serialize(Integer.MAX_VALUE)), Integer.MAX_VALUE);
+    assertEquals(serializer.deserialize(
+        serializer.serialize(Long.MAX_VALUE)), Long.MAX_VALUE);
+
     Map map = new HashMap();
     List list = new ArrayList<>();
     list.add(null);
     list.add(true);
-    list.add(1);
     list.add(1.0d);
     list.add("str");
     map.put("k1", "value1");
-    map.put("k2", 2);
+    map.put("k2", new HashMap<>());
     map.put("k3", list);
     byte[] bytes = serializer.serialize(map);
     Object o = serializer.deserialize(bytes);

[StreamingQueueTest.java]

@@ -5,6 +5,7 @@ import io.ray.api.Ray;
 import io.ray.api.RayActor;
 import io.ray.api.options.ActorCreationOptions;
 import io.ray.api.options.ActorCreationOptions.Builder;
+import io.ray.runtime.config.RayConfig;
 import io.ray.streaming.api.context.StreamingContext;
 import io.ray.streaming.api.function.impl.FlatMapFunction;
 import io.ray.streaming.api.function.impl.ReduceFunction;
@@ -67,7 +68,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
     System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
     System.setProperty("ray.run-mode", "CLUSTER");
     System.setProperty("ray.redirect-output", "true");
-    // ray init
+    RayConfig.reset();
     Ray.init();
   }
 
@@ -142,6 +143,14 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
 
   @Test(timeOut = 60000)
   public void testWordCount() {
+    Ray.shutdown();
+    System.setProperty("ray.resources", "CPU:4,RES-A:4");
+    System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
+
+    System.setProperty("ray.run-mode", "CLUSTER");
+    System.setProperty("ray.redirect-output", "true");
+    // ray init
+    Ray.init();
     LOGGER.info("testWordCount");
     LOGGER.info("StreamingQueueTest.testWordCount run-mode: {}",
         System.getProperty("ray.run-mode"));
@@ -157,7 +166,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
     streamingContext.withConfig(config);
     List<String> text = new ArrayList<>();
     text.add("hello world eagle eagle eagle");
-    DataStreamSource<String> streamSource = DataStreamSource.buildSource(streamingContext, text);
+    DataStreamSource<String> streamSource = DataStreamSource.fromCollection(streamingContext, text);
     streamSource
         .flatMap((FlatMapFunction<String, WordAndCount>) (value, collector) -> {
           String[] records = value.split(" ");
@@ -176,7 +185,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
       serializeResultToFile(resultFile, wordCount);
     });
 
-    streamingContext.execute("testWordCount");
+    streamingContext.execute("testSQWordCount");
 
     Map<String, Integer> checkWordCount =
         (Map<String, Integer>) deserializeResultFromFile(resultFile);

[Java streaming test script]

@@ -23,8 +23,11 @@ bazel test //streaming/java:all --test_tag_filters="checkstyle" --build_tests_on
 
 echo "Running streaming tests."
 java -cp "$ROOT_DIR"/../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar\
-    org.testng.TestNG -d /tmp/ray_streaming_java_test_output "$ROOT_DIR"/testng.xml
+    org.testng.TestNG -d /tmp/ray_streaming_java_test_output "$ROOT_DIR"/testng.xml ||
 exit_code=$?
+if [ -z ${exit_code+x} ]; then
+  exit_code=0
+fi
 echo "Streaming TestNG results"
 if [ -f "/tmp/ray_streaming_java_test_output/testng-results.xml" ] ; then
   cat /tmp/ray_streaming_java_test_output/testng-results.xml

[streaming/python/collector.py]

@@ -1,10 +1,13 @@
 import logging
-import pickle
 import typing
 from abc import ABC, abstractmethod
 
+from ray import Language
+from ray.actor import ActorHandle
+from ray.streaming import function
 from ray.streaming import message
 from ray.streaming import partition
+from ray.streaming.runtime import serialization
 from ray.streaming.runtime.transfer import ChannelID, DataWriter
 
 logger = logging.getLogger(__name__)
@@ -31,19 +34,46 @@ class CollectionCollector(Collector):
 
 
 class OutputCollector(Collector):
-    def __init__(self, channel_ids: typing.List[str], writer: DataWriter,
+    def __init__(self, writer: DataWriter, channel_ids: typing.List[str],
+                 target_actors: typing.List[ActorHandle],
                  partition_func: partition.Partition):
-        self._channel_ids = [ChannelID(id_str) for id_str in channel_ids]
         self._writer = writer
+        self._channel_ids = [ChannelID(id_str) for id_str in channel_ids]
+        self._target_languages = []
+        for actor in target_actors:
+            if actor._ray_actor_language == Language.PYTHON:
+                self._target_languages.append(function.Language.PYTHON)
+            elif actor._ray_actor_language == Language.JAVA:
+                self._target_languages.append(function.Language.JAVA)
+            else:
+                raise Exception("Unsupported language {}"
+                                .format(actor._ray_actor_language))
         self._partition_func = partition_func
+        self.python_serializer = serialization.PythonSerializer()
+        self.cross_lang_serializer = serialization.CrossLangSerializer()
        logger.info(
             "Create OutputCollector, channel_ids {}, partition_func {}".format(
                 channel_ids, partition_func))
 
     def collect(self, record):
-        partitions = self._partition_func.partition(record,
-                                                    len(self._channel_ids))
-        serialized_message = pickle.dumps(record)
+        partitions = self._partition_func \
+            .partition(record, len(self._channel_ids))
+        python_buffer = None
+        cross_lang_buffer = None
         for partition_index in partitions:
-            self._writer.write(self._channel_ids[partition_index],
-                               serialized_message)
+            if self._target_languages[partition_index] == \
+                    function.Language.PYTHON:
+                # avoid repeated serialization
+                if python_buffer is None:
+                    python_buffer = self.python_serializer.serialize(record)
+                self._writer.write(
+                    self._channel_ids[partition_index],
+                    serialization._PYTHON_TYPE_ID + python_buffer)
+            else:
+                # avoid repeated serialization
+                if cross_lang_buffer is None:
+                    cross_lang_buffer = self.cross_lang_serializer.serialize(
+                        record)
+                self._writer.write(
+                    self._channel_ids[partition_index],
+                    serialization._CROSS_LANG_TYPE_ID + cross_lang_buffer)
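The wire format that OutputCollector and the readers agree on is a single type-ID byte followed by the serialized payload: _PYTHON_TYPE_ID marks pickle data for Python peers, _CROSS_LANG_TYPE_ID marks msgpack data for cross-language peers, and each reader strips the first byte before deserializing (InputStreamTask above does the same on the Java side with Serializer.JAVA_TYPE_ID). A self-contained sketch of that framing, using pickle for both branches purely for illustration:

    import pickle

    _CROSS_LANG_TYPE_ID = b"0"  # msgpack payload in the real code
    _PYTHON_TYPE_ID = b"2"      # pickle payload

    def frame(record, target_is_python):
        # one leading byte selects the deserializer on the receiving side
        type_id = _PYTHON_TYPE_ID if target_is_python else _CROSS_LANG_TYPE_ID
        return type_id + pickle.dumps(record)

    def unframe(buffer):
        type_id, payload = buffer[:1], buffer[1:]
        # the real reader dispatches to PythonSerializer or CrossLangSerializer
        assert type_id in (_PYTHON_TYPE_ID, _CROSS_LANG_TYPE_ID)
        return pickle.loads(payload)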
[streaming/python/datastream.py]

@@ -1,4 +1,4 @@
-from abc import ABC
+from abc import ABC, abstractmethod
 
 from ray.streaming import function
 from ray.streaming import partition
@@ -19,7 +19,6 @@ class Stream(ABC):
             self.streaming_context = input_stream.streaming_context
         else:
             self.streaming_context = streaming_context
-        self.parallelism = 1
 
     def get_streaming_context(self):
         return self.streaming_context
@@ -29,7 +28,8 @@ class Stream(ABC):
         Returns:
             the parallelism of this transformation
         """
-        return self.parallelism
+        return self._gateway_client(). \
+            call_method(self._j_stream, "getParallelism")
 
     def set_parallelism(self, parallelism: int):
         """Sets the parallelism of this transformation
@@ -40,7 +40,6 @@ class Stream(ABC):
         Returns:
             self
         """
-        self.parallelism = parallelism
         self._gateway_client(). \
             call_method(self._j_stream, "setParallelism", parallelism)
         return self
@@ -60,6 +59,10 @@ class Stream(ABC):
         return self._gateway_client(). \
             call_method(self._j_stream, "getId")
 
+    @abstractmethod
+    def get_language(self):
+        pass
+
     def _gateway_client(self):
         return self.get_streaming_context()._gateway_client
 
@@ -75,6 +78,9 @@ class DataStream(Stream):
         super().__init__(
             input_stream, j_stream, streaming_context=streaming_context)
 
+    def get_language(self):
+        return function.Language.PYTHON
+
     def map(self, func):
         """
         Applies a Map transformation on a :class:`DataStream`.
@@ -158,6 +164,7 @@ class DataStream(Stream):
         Returns:
             A KeyDataStream
         """
+        self._check_partition_call()
         if not isinstance(func, function.KeyFunction):
             func = function.SimpleKeyFunction(func)
         j_func = self._gateway_client().create_py_func(
@@ -175,6 +182,7 @@ class DataStream(Stream):
         Returns:
             The DataStream with broadcast partitioning set.
         """
+        self._check_partition_call()
         self._gateway_client().call_method(self._j_stream, "broadcast")
         return self
 
@@ -191,6 +199,7 @@ class DataStream(Stream):
         Returns:
             The DataStream with specified partitioning set.
         """
+        self._check_partition_call()
         if not isinstance(partition_func, partition.Partition):
             partition_func = partition.SimplePartition(partition_func)
         j_partition = self._gateway_client().create_py_func(
@@ -199,6 +208,16 @@ class DataStream(Stream):
             call_method(self._j_stream, "partitionBy", j_partition)
         return self
 
+    def _check_partition_call(self):
+        """
+        If parent stream is a java stream, we can't call partition related
+        methods in the python stream
+        """
+        if self.input_stream is not None and \
+                self.input_stream.get_language() == function.Language.JAVA:
+            raise Exception("Partition related methods can't be called on a "
+                            "python stream if parent stream is a java stream.")
+
     def sink(self, func):
         """
         Create a StreamSink with the given sink.
@@ -217,8 +236,97 @@ class DataStream(Stream):
             call_method(self._j_stream, "sink", j_func)
         return StreamSink(self, j_stream, func)
 
+    def as_java_stream(self):
+        """
+        Convert this stream as a java JavaDataStream.
+        The converted stream and this stream are the same logical stream,
+        which has same stream id. Changes in converted stream will be reflected
+        in this stream and vice versa.
+        """
+        j_stream = self._gateway_client(). \
+            call_method(self._j_stream, "asJavaStream")
+        return JavaDataStream(self, j_stream)
 
-class KeyDataStream(Stream):
+
+class JavaDataStream(Stream):
+    """
+    Represents a stream of data which applies a transformation executed by
+    java. It's also a wrapper of java
+    `org.ray.streaming.api.stream.DataStream`
+    """
+
+    def __init__(self, input_stream, j_stream, streaming_context=None):
+        super().__init__(
+            input_stream, j_stream, streaming_context=streaming_context)
+
+    def get_language(self):
+        return function.Language.JAVA
+
+    def map(self, java_func_class):
+        """See org.ray.streaming.api.stream.DataStream.map"""
+        return JavaDataStream(self, self._unary_call("map", java_func_class))
+
+    def flat_map(self, java_func_class):
+        """See org.ray.streaming.api.stream.DataStream.flatMap"""
+        return JavaDataStream(self, self._unary_call("flatMap",
+                                                     java_func_class))
+
+    def filter(self, java_func_class):
+        """See org.ray.streaming.api.stream.DataStream.filter"""
+        return JavaDataStream(self, self._unary_call("filter",
+                                                     java_func_class))
+
+    def key_by(self, java_func_class):
+        """See org.ray.streaming.api.stream.DataStream.keyBy"""
+        self._check_partition_call()
+        return JavaKeyDataStream(self,
+                                 self._unary_call("keyBy", java_func_class))
+
+    def broadcast(self, java_func_class):
+        """See org.ray.streaming.api.stream.DataStream.broadcast"""
+        self._check_partition_call()
+        return JavaDataStream(self,
+                              self._unary_call("broadcast", java_func_class))
+
+    def partition_by(self, java_func_class):
+        """See org.ray.streaming.api.stream.DataStream.partitionBy"""
+        self._check_partition_call()
+        return JavaDataStream(self,
+                              self._unary_call("partitionBy", java_func_class))
+
+    def sink(self, java_func_class):
+        """See org.ray.streaming.api.stream.DataStream.sink"""
+        return JavaStreamSink(self, self._unary_call("sink", java_func_class))
+
+    def as_python_stream(self):
+        """
+        Convert this stream as a python DataStream.
+        The converted stream and this stream are the same logical stream,
+        which has same stream id. Changes in converted stream will be reflected
+        in this stream and vice versa.
+        """
+        j_stream = self._gateway_client(). \
+            call_method(self._j_stream, "asPythonStream")
+        return DataStream(self, j_stream)
+
+    def _check_partition_call(self):
+        """
+        If parent stream is a python stream, we can't call partition related
+        methods in the java stream
+        """
+        if self.input_stream is not None and \
+                self.input_stream.get_language() == function.Language.PYTHON:
+            raise Exception("Partition related methods can't be called on a"
+                            "java stream if parent stream is a python stream.")
+
+    def _unary_call(self, func_name, java_func_class):
+        j_func = self._gateway_client().new_instance(java_func_class)
+        j_stream = self._gateway_client(). \
+            call_method(self._j_stream, func_name, j_func)
+        return j_stream
+
+
+class KeyDataStream(DataStream):
     """Represents a DataStream returned by a key-by operation.
     Wrapper of java io.ray.streaming.python.stream.PythonKeyDataStream
     """
@@ -251,6 +359,43 @@ class KeyDataStream(Stream):
             call_method(self._j_stream, "reduce", j_func)
         return DataStream(self, j_stream)
 
+    def as_java_stream(self):
+        """
+        Convert this stream as a java KeyDataStream.
+        The converted stream and this stream are the same logical stream,
+        which has same stream id. Changes in converted stream will be reflected
+        in this stream and vice versa.
+        """
+        j_stream = self._gateway_client(). \
+            call_method(self._j_stream, "asJavaStream")
+        return JavaKeyDataStream(self, j_stream)
+
+
+class JavaKeyDataStream(JavaDataStream):
+    """
+    Represents a DataStream returned by a key-by operation in java.
+    Wrapper of org.ray.streaming.api.stream.KeyDataStream
+    """
+
+    def __init__(self, input_stream, j_stream):
+        super().__init__(input_stream, j_stream)
+
+    def reduce(self, java_func_class):
+        """See org.ray.streaming.api.stream.KeyDataStream.reduce"""
+        return JavaDataStream(self,
+                              super()._unary_call("reduce", java_func_class))
+
+    def as_python_stream(self):
+        """
+        Convert this stream as a python KeyDataStream.
+        The converted stream and this stream are the same logical stream,
+        which has same stream id. Changes in converted stream will be reflected
+        in this stream and vice versa.
+        """
+        j_stream = self._gateway_client(). \
+            call_method(self._j_stream, "asPythonStream")
+        return KeyDataStream(self, j_stream)
+
 
 class StreamSource(DataStream):
     """Represents a source of the DataStream.
@@ -261,9 +406,12 @@ class StreamSource(DataStream):
         super().__init__(None, j_stream, streaming_context=streaming_context)
         self.source_func = source_func
 
+    def get_language(self):
+        return function.Language.PYTHON
+
     @staticmethod
     def build_source(streaming_context, func):
-        """Build a StreamSource source from a collection.
+        """Build a StreamSource source from a source function.
         Args:
             streaming_context: Stream context
             func: A instance of `SourceFunction`
@@ -275,6 +423,34 @@ class StreamSource(DataStream):
         return StreamSource(j_stream, streaming_context, func)
 
 
+class JavaStreamSource(JavaDataStream):
+    """Represents a source of the java DataStream.
+    Wrapper of java org.ray.streaming.api.stream.DataStreamSource
+    """
+
+    def __init__(self, j_stream, streaming_context):
+        super().__init__(None, j_stream, streaming_context=streaming_context)
+
+    def get_language(self):
+        return function.Language.JAVA
+
+    @staticmethod
+    def build_source(streaming_context, java_source_func_class):
+        """Build a java StreamSource source from a java source function.
+        Args:
+            streaming_context: Stream context
+            java_source_func_class: qualified class name of java SourceFunction
+        Returns:
+            A java StreamSource
+        """
+        j_func = streaming_context._gateway_client() \
+            .new_instance(java_source_func_class)
+        j_stream = streaming_context._gateway_client() \
+            .call_function("org.ray.streaming.api.stream.DataStreamSource"
+                           "fromSource", streaming_context._j_ctx, j_func)
+        return JavaStreamSource(j_stream, streaming_context)
+
+
 class StreamSink(Stream):
     """Represents a sink of the DataStream.
     Wrapper of java io.ray.streaming.python.stream.PythonStreamSink
@@ -282,3 +458,18 @@ class StreamSink(Stream):
 
     def __init__(self, input_stream, j_stream, func):
         super().__init__(input_stream, j_stream)
+
+    def get_language(self):
+        return function.Language.PYTHON
+
+
+class JavaStreamSink(Stream):
+    """Represents a sink of the java DataStream.
+    Wrapper of java org.ray.streaming.api.stream.StreamSink
+    """
+
+    def __init__(self, input_stream, j_stream):
+        super().__init__(input_stream, j_stream)
+
+    def get_language(self):
+        return function.Language.JAVA
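Together these classes let one logical pipeline hop between languages. A hypothetical Python-side usage sketch built only from the methods added above (stream construction is elided, and com.example.MyJavaMapFunction is a placeholder class name, not part of this commit):

    # `stream` is a python DataStream obtained from a StreamingContext source
    java_stream = stream.as_java_stream()      # same logical stream, java view
    mapped = java_stream.map("com.example.MyJavaMapFunction")
    back = mapped.as_python_stream()           # hop back to python
    back.sink(lambda record: print(record))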
[streaming/python/function.py]

@@ -1,13 +1,19 @@
+import enum
 import importlib
 import inspect
 import sys
-from abc import ABC, abstractmethod
 import typing
+from abc import ABC, abstractmethod
 
 from ray import cloudpickle
 from ray.streaming.runtime import gateway_client
 
 
+class Language(enum.Enum):
+    JAVA = 0
+    PYTHON = 1
+
+
 class Function(ABC):
     """The base interface for all user-defined functions."""
 
@@ -60,6 +66,7 @@ class MapFunction(Function):
     for each input element.
     """
 
+    @abstractmethod
     def map(self, value):
         pass
 
@@ -70,6 +77,7 @@ class FlatMapFunction(Function):
    transform them into zero, one, or more elements.
     """
 
+    @abstractmethod
     def flat_map(self, value, collector):
         """Takes an element from the input data set and transforms it into zero,
         one, or more elements.
@@ -87,6 +95,7 @@ class FilterFunction(Function):
     The predicate decides whether to keep the element, or to discard it.
     """
 
+    @abstractmethod
     def filter(self, value):
         """The filter function that evaluates the predicate.
 
@@ -106,6 +115,7 @@ class KeyFunction(Function):
     deterministic key for that object.
     """
 
+    @abstractmethod
     def key_by(self, value):
         """User-defined function that deterministically extracts the key from
         an object.
@@ -126,6 +136,7 @@ class ReduceFunction(Function):
     them into one.
     """
 
+    @abstractmethod
     def reduce(self, old_value, new_value):
         """
         The core method of ReduceFunction, combining two values into one value
@@ -145,6 +156,7 @@ class ReduceFunction(Function):
 class SinkFunction(Function):
     """Interface for implementing user defined sink functionality."""
 
+    @abstractmethod
     def sink(self, value):
         """Writes the given value to the sink. This function is called for
         every record."""
@@ -283,7 +295,8 @@ def load_function(descriptor_func_bytes: bytes):
     Returns:
         a streaming function
     """
-    function_bytes, module_name, class_name, function_name, function_interface\
+    assert len(descriptor_func_bytes) > 0
+    function_bytes, module_name, function_name, function_interface\
         = gateway_client.deserialize(descriptor_func_bytes)
     if function_bytes:
         return deserialize(function_bytes)
@@ -292,16 +305,18 @@ def load_function(descriptor_func_bytes: bytes):
     assert function_interface
     function_interface = getattr(sys.modules[__name__], function_interface)
     mod = importlib.import_module(module_name)
-    if class_name:
-        assert function_name is None
-        cls = getattr(mod, class_name)
-        assert issubclass(cls, function_interface)
-        return cls()
-    else:
     assert function_name
     func = getattr(mod, function_name)
+    # If func is a python function, user function is a simple python
+    # function, which will be wrapped as a SimpleXXXFunction.
+    # If func is a python class, user function is a sub class
+    # of XXXFunction.
+    if inspect.isfunction(func):
         simple_func_class = _get_simple_function_class(function_interface)
         return simple_func_class(func)
+    else:
+        assert issubclass(func, function_interface)
+        return func()
 
 
 def _get_simple_function_class(function_interface):
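With the class_name field dropped from the descriptor, load_function now inspects the resolved attribute itself: a plain module-level function is wrapped in the matching Simple*Function adapter, while a class must subclass the expected interface and is instantiated directly. A hedged sketch of that resolution step (module and attribute names below are illustrative):

    import importlib
    import inspect

    def resolve(module_name, name, function_interface):
        # mirrors the no-bytes branch of the new load_function
        func = getattr(importlib.import_module(module_name), name)
        if inspect.isfunction(func):
            # a bare function gets wrapped, e.g. SimpleMapFunction(func)
            return _get_simple_function_class(function_interface)(func)
        assert issubclass(func, function_interface)
        return func()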
[streaming/python/message.py]

@@ -8,6 +8,14 @@ class Record:
     def __repr__(self):
         return "Record(%s)".format(self.value)
 
+    def __eq__(self, other):
+        if type(self) is type(other):
+            return (self.stream, self.value) == (other.stream, other.value)
+        return False
+
+    def __hash__(self):
+        return hash((self.stream, self.value))
+
 
 class KeyRecord(Record):
     """Data record in a keyed data stream"""
@@ -15,3 +23,12 @@ class KeyRecord(Record):
     def __init__(self, key, value):
         super().__init__(value)
         self.key = key
+
+    def __eq__(self, other):
+        if type(self) is type(other):
+            return (self.stream, self.key, self.value) ==\
+                   (other.stream, other.key, other.value)
+        return False
+
+    def __hash__(self):
+        return hash((self.stream, self.key, self.value))

[streaming/python/partition.py]

@@ -1,4 +1,5 @@
 import importlib
+import inspect
 from abc import ABC, abstractmethod
 
 from ray import cloudpickle
@@ -96,22 +97,22 @@ def load_partition(descriptor_partition_bytes: bytes):
     Returns:
         partition function
     """
-    partition_bytes, module_name, class_name, function_name =\
+    assert len(descriptor_partition_bytes) > 0
+    partition_bytes, module_name, function_name =\
         gateway_client.deserialize(descriptor_partition_bytes)
     if partition_bytes:
         return deserialize(partition_bytes)
     else:
         assert module_name
         mod = importlib.import_module(module_name)
-        # If class_name is not None, user partition is a sub class
-        # of Partition.
-        # If function_name is not None, user partition is a simple python
-        # function, which will be wrapped as a SimplePartition.
-        if class_name:
-            assert function_name is None
-            cls = getattr(mod, class_name)
-            return cls()
-        else:
         assert function_name
         func = getattr(mod, function_name)
+        # If func is a python function, user partition is a simple python
+        # function, which will be wrapped as a SimplePartition.
+        # If func is a python class, user partition is a sub class
+        # of Partition.
+        if inspect.isfunction(func):
             return SimplePartition(func)
+        else:
+            assert issubclass(func, Partition)
+            return func()

[streaming/python/runtime/gateway_client.py]

@@ -55,6 +55,11 @@ class GatewayClient:
         call = self._python_gateway_actor.callMethod.remote(java_params)
         return deserialize(ray.get(call))
 
+    def new_instance(self, java_class_name):
+        call = self._python_gateway_actor.newInstance.remote(
+            serialize(java_class_name))
+        return deserialize(ray.get(call))
+
 
 def serialize(obj) -> bytes:
     """Serialize a python object which can be deserialized by `PythonGateway`

[streaming/python/runtime/graph.py]

@@ -53,7 +53,9 @@ class ExecutionEdge:
         self.src_node_id = edge_pb.src_node_id
         self.target_node_id = edge_pb.target_node_id
         partition_bytes = edge_pb.partition
-        if language == Language.PYTHON:
+        # Sink node doesn't have partition function,
+        # so we only deserialize partition_bytes when it's not None or empty
+        if language == Language.PYTHON and partition_bytes:
             self.partition = partition.load_partition(partition_bytes)
 
 

[streaming/python/runtime/serialization.py (new file, 57 lines)]

@@ -0,0 +1,57 @@
+from abc import ABC, abstractmethod
+import pickle
+import msgpack
+from ray.streaming import message
+
+_RECORD_TYPE_ID = 0
+_KEY_RECORD_TYPE_ID = 1
+_CROSS_LANG_TYPE_ID = b"0"
+_JAVA_TYPE_ID = b"1"
+_PYTHON_TYPE_ID = b"2"
+
+
+class Serializer(ABC):
+    @abstractmethod
+    def serialize(self, obj):
+        pass
+
+    @abstractmethod
+    def deserialize(self, serialized_bytes):
+        pass
+
+
+class PythonSerializer(Serializer):
+    def serialize(self, obj):
+        return pickle.dumps(obj)
+
+    def deserialize(self, serialized_bytes):
+        return pickle.loads(serialized_bytes)
+
+
+class CrossLangSerializer(Serializer):
+    """Serialize stream element between java/python"""
+
+    def serialize(self, obj):
+        if type(obj) is message.Record:
+            fields = [_RECORD_TYPE_ID, obj.stream, obj.value]
+        elif type(obj) is message.KeyRecord:
+            fields = [_KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value]
+        else:
+            raise Exception("Unsupported value {}".format(obj))
+        return msgpack.packb(fields, use_bin_type=True)
+
+    def deserialize(self, data):
+        fields = msgpack.unpackb(data, raw=False)
+        if fields[0] == _RECORD_TYPE_ID:
+            stream, value = fields[1:]
+            record = message.Record(value)
+            record.stream = stream
+            return record
+        elif fields[0] == _KEY_RECORD_TYPE_ID:
+            stream, key, value = fields[1:]
+            key_record = message.KeyRecord(key, value)
+            key_record.stream = stream
+            return key_record
+        else:
+            raise Exception("Unsupported type id {}, type {}".format(
+                fields[0], type(fields[0])))
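A quick round trip through the new CrossLangSerializer, matching what CrossLangSerializerTest exercises from the Java side; the msgpack field layout ([type_id, stream, (key,) value]) is the contract both languages agree on:

    from ray.streaming import message
    from ray.streaming.runtime.serialization import CrossLangSerializer

    serializer = CrossLangSerializer()
    record = message.KeyRecord("key", "value")
    record.stream = "stream2"
    data = serializer.serialize(record)  # msgpack [1, "stream2", "key", "value"]
    assert serializer.deserialize(data) == record  # uses the new KeyRecord.__eq__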
@ -1,11 +1,13 @@
|
||||||
import logging
|
import logging
|
||||||
import pickle
|
|
||||||
import threading
|
import threading
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from ray.streaming.collector import OutputCollector
|
from ray.streaming.collector import OutputCollector
|
||||||
from ray.streaming.config import Config
|
from ray.streaming.config import Config
|
||||||
from ray.streaming.context import RuntimeContextImpl
|
from ray.streaming.context import RuntimeContextImpl
|
||||||
|
from ray.streaming.runtime import serialization
|
||||||
|
from ray.streaming.runtime.serialization import \
|
||||||
|
PythonSerializer, CrossLangSerializer
|
||||||
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader
|
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@@ -38,36 +40,40 @@ class StreamTask(ABC):
         # writers
         collectors = []
         for edge in execution_node.output_edges:
-            output_actor_ids = {}
+            output_actors_map = {}
             task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
                 edge.target_node_id)
             for target_task_id, target_actor in task_id2_worker.items():
                 channel_name = ChannelID.gen_id(self.task_id, target_task_id,
                                                 execution_graph.build_time())
-                output_actor_ids[channel_name] = target_actor
-            if len(output_actor_ids) > 0:
-                channel_ids = list(output_actor_ids.keys())
-                to_actor_ids = list(output_actor_ids.values())
-                writer = DataWriter(channel_ids, to_actor_ids, channel_conf)
-                logger.info("Create DataWriter succeed.")
+                output_actors_map[channel_name] = target_actor
+            if len(output_actors_map) > 0:
+                channel_ids = list(output_actors_map.keys())
+                target_actors = list(output_actors_map.values())
+                logger.info(
+                    "Create DataWriter channel_ids {}, target_actors {}."
+                    .format(channel_ids, target_actors))
+                writer = DataWriter(channel_ids, target_actors, channel_conf)
                 self.writers[edge] = writer
                 collectors.append(
-                    OutputCollector(channel_ids, writer, edge.partition))
+                    OutputCollector(writer, channel_ids, target_actors,
+                                    edge.partition))

         # readers
-        input_actor_ids = {}
+        input_actor_map = {}
         for edge in execution_node.input_edges:
             task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
                 edge.src_node_id)
             for src_task_id, src_actor in task_id2_worker.items():
                 channel_name = ChannelID.gen_id(src_task_id, self.task_id,
                                                 execution_graph.build_time())
-                input_actor_ids[channel_name] = src_actor
-        if len(input_actor_ids) > 0:
-            channel_ids = list(input_actor_ids.keys())
-            from_actor_ids = list(input_actor_ids.values())
-            logger.info("Create DataReader, channels {}.".format(channel_ids))
-            self.reader = DataReader(channel_ids, from_actor_ids, channel_conf)
+                input_actor_map[channel_name] = src_actor
+        if len(input_actor_map) > 0:
+            channel_ids = list(input_actor_map.keys())
+            from_actors = list(input_actor_map.values())
+            logger.info("Create DataReader, channels {}, input_actors {}."
+                        .format(channel_ids, from_actors))
+            self.reader = DataReader(channel_ids, from_actors, channel_conf)

         def exit_handler():
             # Make DataReader stop read data when MockQueue destructor
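The naming symmetry above is what wires the two ends of an edge together: the writer names each channel with ChannelID.gen_id(self.task_id, target_task_id, build_time) while the matching reader calls ChannelID.gen_id(src_task_id, self.task_id, build_time), so both actors derive the same channel id without any coordination. A minimal sketch of that property, using a hypothetical gen_id stand-in (the real implementation lives in ray.streaming.runtime.transfer):

import hashlib

def gen_id(from_task_id, to_task_id, build_time):
    # Hypothetical stand-in for ChannelID.gen_id: any deterministic
    # function of (from, to, build_time) gives both peers the same name.
    key = "{}-{}-{}".format(from_task_id, to_task_id, build_time)
    return hashlib.sha1(key.encode()).hexdigest()

# The writer on task 3 and the reader on task 7 compute the id independently:
assert gen_id(3, 7, 1584000000) == gen_id(3, 7, 1584000000)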
@@ -111,6 +117,8 @@ class InputStreamTask(StreamTask):
         self.read_timeout_millis = \
             int(worker.config.get(Config.READ_TIMEOUT_MS,
                                   Config.DEFAULT_READ_TIMEOUT_MS))
+        self.python_serializer = PythonSerializer()
+        self.cross_lang_serializer = CrossLangSerializer()

     def init(self):
         pass
@@ -120,7 +128,11 @@ class InputStreamTask(StreamTask):
             item = self.reader.read(self.read_timeout_millis)
             if item is not None:
                 msg_data = item.body()
-                msg = pickle.loads(msg_data)
+                type_id = msg_data[:1]
+                if (type_id == serialization._PYTHON_TYPE_ID):
+                    msg = self.python_serializer.deserialize(msg_data[1:])
+                else:
+                    msg = self.cross_lang_serializer.deserialize(msg_data[1:])
                 self.processor.process(msg)
         self.stopped = True
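The one-byte prefix read above is the cross-language framing convention: every message body starts with a type-id byte telling the reader which serializer produced the remaining payload, so Python-to-Python payloads can keep using pickle while messages crossing the language boundary use the language-neutral format. A minimal sketch of the convention; the type-id byte values here are assumptions for illustration (the real constants live in ray.streaming.runtime.serialization):

# Assumed placeholder values, not the real constants.
_PYTHON_TYPE_ID = b"\x00"      # payload produced by PythonSerializer
_CROSS_LANG_TYPE_ID = b"\x01"  # payload in the cross-language format

def frame(type_id, payload):
    # Writer side: prepend the type id so the reader can dispatch.
    return type_id + payload

def dispatch(msg_data):
    # Reader side: peel off the first byte and route to a deserializer.
    type_id, payload = msg_data[:1], msg_data[1:]
    return ("python" if type_id == _PYTHON_TYPE_ID else "cross_lang", payload)

assert dispatch(frame(_PYTHON_TYPE_ID, b"abc")) == ("python", b"abc")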
@@ -147,13 +147,17 @@ class ChannelCreationParametersBuilder:
    wrap initial parameters needed by a streaming queue
    """
    _java_reader_async_function_descriptor = JavaFunctionDescriptor(
-        "io.ray.streaming.runtime.worker", "onReaderMessage", "([B)V")
+        "io.ray.streaming.runtime.worker.JobWorker", "onReaderMessage",
+        "([B)V")
    _java_reader_sync_function_descriptor = JavaFunctionDescriptor(
-        "io.ray.streaming.runtime.worker", "onReaderMessageSync", "([B)[B")
+        "io.ray.streaming.runtime.worker.JobWorker", "onReaderMessageSync",
+        "([B)[B")
    _java_writer_async_function_descriptor = JavaFunctionDescriptor(
-        "io.ray.streaming.runtime.worker", "onWriterMessage", "([B)V")
+        "io.ray.streaming.runtime.worker.JobWorker", "onWriterMessage",
+        "([B)V")
    _java_writer_sync_function_descriptor = JavaFunctionDescriptor(
-        "io.ray.streaming.runtime.worker", "onWriterMessageSync", "([B)[B")
+        "io.ray.streaming.runtime.worker.JobWorker", "onWriterMessageSync",
+        "([B)[B")
    _python_reader_async_function_descriptor = PythonFunctionDescriptor(
        "ray.streaming.runtime.worker", "on_reader_message", "JobWorker")
    _python_reader_sync_function_descriptor = PythonFunctionDescriptor(
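Two things change above: the Java descriptors now name the declaring class (io.ray.streaming.runtime.worker.JobWorker) rather than just the package, and each descriptor carries a standard JVM method signature, where "([B)V" means one byte[] argument returning void and "([B)[B" means one byte[] argument returning byte[]. For reference, the four Java peer methods as plain tuples (illustration only, not the real JavaFunctionDescriptor objects):

# (declaring class, method name, JVM signature)
JAVA_PEER_METHODS = {
    "reader_async": ("io.ray.streaming.runtime.worker.JobWorker",
                     "onReaderMessage", "([B)V"),       # byte[] -> void
    "reader_sync": ("io.ray.streaming.runtime.worker.JobWorker",
                    "onReaderMessageSync", "([B)[B"),   # byte[] -> byte[]
    "writer_async": ("io.ray.streaming.runtime.worker.JobWorker",
                     "onWriterMessage", "([B)V"),       # byte[] -> void
    "writer_sync": ("io.ray.streaming.runtime.worker.JobWorker",
                    "onWriterMessageSync", "([B)[B"),   # byte[] -> byte[]
}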
@@ -10,6 +10,9 @@ from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask

 logger = logging.getLogger(__name__)

+# special flag to indicate this actor not ready
+_NOT_READY_FLAG_ = b" " * 4
+

 @ray.remote
 class JobWorker(object):
@@ -66,23 +69,31 @@ class JobWorker(object):
                     type(self.stream_processor))

     def on_reader_message(self, buffer: bytes):
-        """used in direct call mode"""
+        """Called by upstream queue writer to send data message to downstream
+        queue reader.
+        """
         self.reader_client.on_reader_message(buffer)

     def on_reader_message_sync(self, buffer: bytes):
-        """used in direct call mode"""
+        """Called by upstream queue writer to send control message to
+        downstream queue reader.
+        """
         if self.reader_client is None:
-            return b" " * 4  # special flag to indicate this actor not ready
+            return _NOT_READY_FLAG_
         result = self.reader_client.on_reader_message_sync(buffer)
         return result.to_pybytes()

     def on_writer_message(self, buffer: bytes):
-        """used in direct call mode"""
+        """Called by downstream queue reader to send notify message to
+        upstream queue writer.
+        """
         self.writer_client.on_writer_message(buffer)

     def on_writer_message_sync(self, buffer: bytes):
-        """used in direct call mode"""
+        """Called by downstream queue reader to send control message to
+        upstream queue writer.
+        """
         if self.writer_client is None:
-            return b" " * 4  # special flag to indicate this actor not ready
+            return _NOT_READY_FLAG_
         result = self.writer_client.on_writer_message_sync(buffer)
         return result.to_pybytes()
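Returning _NOT_READY_FLAG_ instead of raising keeps the sync call path non-fatal while the peer actor is still initializing: the caller gets a recognizable 4-byte sentinel rather than an exception. A caller-side sketch of how that sentinel might be handled; the retry helper is hypothetical, not code from this commit:

import time

_NOT_READY_FLAG_ = b" " * 4  # must match the worker-side sentinel

def call_until_ready(sync_call, buffer, retries=10, backoff_s=0.1):
    # Hypothetical helper: retry a sync peer call while the actor
    # reports "not ready" instead of treating the sentinel as a reply.
    for _ in range(retries):
        result = sync_call(buffer)
        if result != _NOT_READY_FLAG_:
            return result
        time.sleep(backoff_s)
    raise TimeoutError("peer actor did not become ready")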
@@ -14,9 +14,9 @@ class MapFunc(function.MapFunction):


 def test_load_function():
-    # function_bytes, module_name, class_name, function_name,
+    # function_bytes, module_name, function_name/class_name,
     # function_interface
     descriptor_func_bytes = gateway_client.serialize(
-        [None, __name__, MapFunc.__name__, None, "MapFunction"])
+        [None, __name__, MapFunc.__name__, "MapFunction"])
     func = function.load_function(descriptor_func_bytes)
     assert type(func) is MapFunc
streaming/python/tests/test_hybrid_stream.py (new file, 70 lines)
@@ -0,0 +1,70 @@
+import json
+import ray
+from ray.streaming import StreamingContext
+import subprocess
+import os
+
+
+def map_func1(x):
+    print("HybridStreamTest map_func1", x)
+    return str(x)
+
+
+def filter_func1(x):
+    print("HybridStreamTest filter_func1", x)
+    return "b" not in x
+
+
+def sink_func1(x):
+    print("HybridStreamTest sink_func1 value:", x)
+
+
+def test_hybrid_stream():
+    subprocess.check_call(
+        ["bazel", "build", "//streaming/java:all_streaming_tests_deploy.jar"])
+    current_dir = os.path.abspath(os.path.dirname(__file__))
+    jar_path = os.path.join(
+        current_dir,
+        "../../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar")
+    jar_path = os.path.abspath(jar_path)
+    print("jar_path", jar_path)
+    java_worker_options = json.dumps(["-classpath", jar_path])
+    print("java_worker_options", java_worker_options)
+    assert not ray.is_initialized()
+    ray.init(
+        load_code_from_local=True,
+        include_java=True,
+        java_worker_options=java_worker_options,
+        _internal_config=json.dumps({
+            "num_workers_per_process_java": 1
+        }))
+
+    sink_file = "/tmp/ray_streaming_test_hybrid_stream.txt"
+    if os.path.exists(sink_file):
+        os.remove(sink_file)
+
+    def sink_func(x):
+        print("HybridStreamTest", x)
+        with open(sink_file, "a") as f:
+            f.write(str(x))
+
+    ctx = StreamingContext.Builder().build()
+    ctx.from_values("a", "b", "c") \
+        .as_java_stream() \
+        .map("io.ray.streaming.runtime.demo.HybridStreamTest$Mapper1") \
+        .filter("io.ray.streaming.runtime.demo.HybridStreamTest$Filter1") \
+        .as_python_stream() \
+        .sink(sink_func)
+    ctx.submit("HybridStreamTest")
+    import time
+    time.sleep(3)
+    ray.shutdown()
+    with open(sink_file, "r") as f:
+        result = f.read()
+    assert "a" in result
+    assert "b" not in result
+    assert "c" in result
+
+
+if __name__ == "__main__":
+    test_hybrid_stream()
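This test exercises the full cross-language loop: a Python source feeds Java map/filter operators loaded from the bazel-built jar (handed to each Java worker through java_worker_options as a -classpath entry), and the results flow back into a Python sink.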
streaming/python/tests/test_serialization.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+from ray.streaming.runtime.serialization import CrossLangSerializer
+from ray.streaming.message import Record, KeyRecord
+
+
+def test_serialize():
+    serializer = CrossLangSerializer()
+    record = Record("value")
+    record.stream = "stream1"
+    key_record = KeyRecord("key", "value")
+    key_record.stream = "stream2"
+    assert record == serializer.deserialize(serializer.serialize(record))
+    assert key_record == serializer.\
+        deserialize(serializer.serialize(key_record))
streaming/python/tests/test_stream.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+import ray
+from ray.streaming import StreamingContext
+
+
+def test_data_stream():
+    ray.init(load_code_from_local=True, include_java=True)
+    ctx = StreamingContext.Builder().build()
+    stream = ctx.from_values(1, 2, 3)
+    java_stream = stream.as_java_stream()
+    python_stream = java_stream.as_python_stream()
+    assert stream.get_id() == java_stream.get_id()
+    assert stream.get_id() == python_stream.get_id()
+    python_stream.set_parallelism(10)
+    assert stream.get_parallelism() == java_stream.get_parallelism()
+    assert stream.get_parallelism() == python_stream.get_parallelism()
+    ray.shutdown()
+
+
+def test_key_data_stream():
+    ray.init(load_code_from_local=True, include_java=True)
+    ctx = StreamingContext.Builder().build()
+    key_stream = ctx.from_values(
+        "a", "b", "c").map(lambda x: (x, 1)).key_by(lambda x: x[0])
+    java_stream = key_stream.as_java_stream()
+    python_stream = java_stream.as_python_stream()
+    assert key_stream.get_id() == java_stream.get_id()
+    assert key_stream.get_id() == python_stream.get_id()
+    python_stream.set_parallelism(10)
+    assert key_stream.get_parallelism() == java_stream.get_parallelism()
+    assert key_stream.get_parallelism() == python_stream.get_parallelism()
+    ray.shutdown()
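These two tests pin down an invariant of the cross-language API: as_java_stream() and as_python_stream() return views over the same underlying stream, so stream id and parallelism stay consistent no matter which view reads or mutates them.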
@@ -32,7 +32,9 @@ def test_simple_word_count():

     def sink_func(x):
         with open(sink_file, "a") as f:
-            f.write("{}:{},".format(x[0], x[1]))
+            line = "{}:{},".format(x[0], x[1])
+            print("sink_func", line)
+            f.write(line)

     ctx.from_values("a", "b", "c") \
         .set_parallelism(1) \
@@ -26,6 +26,13 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(
   return reinterpret_cast<jlong>(reader_client);
 }

+JNIEXPORT void JNICALL
+Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative(
+    JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
+  auto *writer_client = reinterpret_cast<WriterClient *>(ptr);
+  writer_client->OnWriterMessage(JByteArrayToBuffer(env, bytes));
+}
+
 JNIEXPORT jbyteArray JNICALL
 Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative(
     JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
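The added handleWriterMessageNative is the fire-and-forget counterpart of the existing handleWriterMessageSyncNative: both recover the WriterClient* from the jlong handle, but only the sync variant returns a reply buffer to the JVM.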