diff --git a/python/ray/worker.py b/python/ray/worker.py
index f28ac6fbc..cec67dba6 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -542,6 +542,7 @@ def init(address=None,
raylet_socket_name=None,
temp_dir=None,
load_code_from_local=False,
+ java_worker_options=None,
use_pickle=True,
_internal_config=None,
lru_evict=False):
@@ -651,6 +652,7 @@ def init(address=None,
conventional location, e.g., "/tmp/ray".
load_code_from_local: Whether code should be loaded from a local
module or from the GCS.
+ java_worker_options: Overwrite the options to start Java workers.
use_pickle: Deprecated.
_internal_config (str): JSON configuration for overriding
RayConfig defaults. For testing purposes ONLY.
@@ -758,6 +760,7 @@ def init(address=None,
raylet_socket_name=raylet_socket_name,
temp_dir=temp_dir,
load_code_from_local=load_code_from_local,
+ java_worker_options=java_worker_options,
_internal_config=_internal_config,
)
# Start the Ray processes. We set shutdown_at_exit=False because we
@@ -808,6 +811,9 @@ def init(address=None,
if raylet_socket_name is not None:
raise ValueError("When connecting to an existing cluster, "
"raylet_socket_name must not be provided.")
+ if java_worker_options is not None:
+ raise ValueError("When connecting to an existing cluster, "
+ "java_worker_options must not be provided.")
if _internal_config is not None and len(_internal_config) != 0:
raise ValueError("When connecting to an existing cluster, "
"_internal_config must not be provided.")
diff --git a/streaming/java/BUILD.bazel b/streaming/java/BUILD.bazel
index ccaefb3b0..8d66ebd91 100644
--- a/streaming/java/BUILD.bazel
+++ b/streaming/java/BUILD.bazel
@@ -39,6 +39,7 @@ define_java_module(
":io_ray_ray_streaming-state",
":io_ray_ray_streaming-api",
"@ray_streaming_maven//:com_google_guava_guava",
+ "@ray_streaming_maven//:org_apache_commons_commons_lang3",
"@ray_streaming_maven//:org_slf4j_slf4j_api",
"@ray_streaming_maven//:org_slf4j_slf4j_log4j12",
"@ray_streaming_maven//:org_testng_testng",
@@ -46,7 +47,12 @@ define_java_module(
visibility = ["//visibility:public"],
deps = [
":io_ray_ray_streaming-state",
+ "//java:io_ray_ray_api",
+ "//java:io_ray_ray_runtime",
+ "@ray_streaming_maven//:com_google_code_findbugs_jsr305",
+ "@ray_streaming_maven//:com_google_code_gson_gson",
"@ray_streaming_maven//:com_google_guava_guava",
+ "@ray_streaming_maven//:org_apache_commons_commons_lang3",
"@ray_streaming_maven//:org_slf4j_slf4j_api",
"@ray_streaming_maven//:org_slf4j_slf4j_log4j12",
],
@@ -129,8 +135,9 @@ define_java_module(
":io_ray_ray_streaming-api",
":io_ray_ray_streaming-runtime",
"@ray_streaming_maven//:com_google_guava_guava",
+ "@ray_streaming_maven//:com_google_code_findbugs_jsr305",
+ "@ray_streaming_maven//:org_apache_commons_commons_lang3",
"@ray_streaming_maven//:de_ruedigermoeller_fst",
- "@ray_streaming_maven//:org_msgpack_msgpack_core",
"@ray_streaming_maven//:org_aeonbits_owner_owner",
"@ray_streaming_maven//:org_slf4j_slf4j_api",
"@ray_streaming_maven//:org_slf4j_slf4j_log4j12",
@@ -146,10 +153,12 @@ define_java_module(
"//java:io_ray_ray_api",
"//java:io_ray_ray_runtime",
"@ray_streaming_maven//:com_github_davidmoten_flatbuffers_java",
+ "@ray_streaming_maven//:com_google_code_findbugs_jsr305",
"@ray_streaming_maven//:com_google_guava_guava",
"@ray_streaming_maven//:com_google_protobuf_protobuf_java",
"@ray_streaming_maven//:de_ruedigermoeller_fst",
"@ray_streaming_maven//:org_aeonbits_owner_owner",
+ "@ray_streaming_maven//:org_apache_commons_commons_lang3",
"@ray_streaming_maven//:org_msgpack_msgpack_core",
"@ray_streaming_maven//:org_slf4j_slf4j_api",
"@ray_streaming_maven//:org_slf4j_slf4j_log4j12",
diff --git a/streaming/java/dependencies.bzl b/streaming/java/dependencies.bzl
index 40327336d..998d88434 100644
--- a/streaming/java/dependencies.bzl
+++ b/streaming/java/dependencies.bzl
@@ -6,8 +6,11 @@ def gen_streaming_java_deps():
artifacts = [
"com.beust:jcommander:1.72",
"com.google.guava:guava:27.0.1-jre",
+ "com.google.code.findbugs:jsr305:3.0.2",
+ "com.google.code.gson:gson:2.8.5",
"com.github.davidmoten:flatbuffers-java:1.9.0.1",
"com.google.protobuf:protobuf-java:3.8.0",
+ "org.apache.commons:commons-lang3:3.4",
"de.ruedigermoeller:fst:2.57",
"org.aeonbits.owner:owner:1.0.10",
"org.slf4j:slf4j-api:1.7.12",
@@ -19,10 +22,8 @@ def gen_streaming_java_deps():
- "org.apache.commons:commons-lang3:3.3.2",
"org.msgpack:msgpack-core:0.8.20",
"org.testng:testng:6.9.10",
- "org.mockito:mockito-all:1.10.19",
- "org.powermock:powermock-module-testng:1.6.6",
- "org.powermock:powermock-api-mockito:1.6.6",
- "org.projectlombok:lombok:1.16.20",
+ "org.mockito:mockito-all:1.10.19",
+ "org.powermock:powermock-module-testng:1.6.6",
+ "org.powermock:powermock-api-mockito:1.6.6",
],
repositories = [
"https://repo1.maven.org/maven2/",
diff --git a/streaming/java/streaming-api/pom.xml b/streaming/java/streaming-api/pom.xml
index 253f7a3b4..4e100fefd 100644
--- a/streaming/java/streaming-api/pom.xml
+++ b/streaming/java/streaming-api/pom.xml
@@ -22,16 +22,36 @@
ray-api${project.version}
+
+ io.ray
+ ray-runtime
+ ${project.version}
+ org.raystreaming-state${project.version}
+ com.google.code.findbugs
+ jsr305
+ 3.0.2
+
+
+ com.google.code.gson
+ gson
+ 2.8.5
+
+com.google.guavaguava27.0.1-jre
+
+ org.apache.commons
+ commons-lang3
+ 3.4
+org.slf4jslf4j-api
diff --git a/streaming/java/streaming-api/pom_template.xml b/streaming/java/streaming-api/pom_template.xml
index 7c7171cdc..9b94fb278 100644
--- a/streaming/java/streaming-api/pom_template.xml
+++ b/streaming/java/streaming-api/pom_template.xml
@@ -22,6 +22,11 @@
ray-api${project.version}
+
+ io.ray
+ ray-runtime
+ ${project.version}
+ org.raystreaming-state
diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/ClusterStarter.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/ClusterStarter.java
new file mode 100644
index 000000000..0fe98798c
--- /dev/null
+++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/ClusterStarter.java
@@ -0,0 +1,164 @@
+package io.ray.streaming.api.context;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.gson.Gson;
+import io.ray.api.Ray;
+import io.ray.runtime.config.RayConfig;
+import io.ray.runtime.util.NetworkUtil;
+import java.io.File;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Starts and stops the Ray runtime that a streaming job runs on.
+ *
+ * <p>Jobs that are not cross-language only call {@code Ray.init()} with the run mode selected
+ * through system properties. Cross-language jobs first launch a head node with the
+ * {@code ray start} command and then connect to it.
+ */
+class ClusterStarter {
+  private static final Logger LOG = LoggerFactory.getLogger(ClusterStarter.class);
+  private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/plasma_store_socket";
+  private static final String RAYLET_SOCKET_NAME = "/tmp/ray/raylet_socket";
+  private static final int REDIS_PORT = 6379;
+  private static final int COMMAND_TIMEOUT_SECONDS = 10;
+
+  /**
+   * Initialize Ray for a streaming job. Requires {@code Ray.internal()} to still be null.
+   *
+   * @param isCrossLanguage Whether to launch a head node with {@code ray start} and connect
+   *     to it instead of only calling {@code Ray.init()}.
+   * @param isLocal Whether to use run mode SINGLE_PROCESS instead of CLUSTER.
+   * @throws RuntimeException If the {@code ray start} command fails or times out.
+   */
+  static synchronized void startCluster(boolean isCrossLanguage, boolean isLocal) {
+    Preconditions.checkArgument(Ray.internal() == null);
+    RayConfig.reset();
+    if (!isLocal) {
+      System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
+      System.setProperty("ray.run-mode", "CLUSTER");
+    } else {
+      System.clearProperty("ray.raylet.config.num_workers_per_process_java");
+      System.setProperty("ray.run-mode", "SINGLE_PROCESS");
+    }
+
+    if (!isCrossLanguage) {
+      Ray.init();
+      return;
+    }
+
+    // Delete existing socket files.
+    for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
+      File file = new File(socket);
+      if (file.exists()) {
+        LOG.info("Delete existing socket file {}", file);
+        if (!file.delete()) {
+          LOG.warn("Failed to delete existing socket file {}", file);
+        }
+      }
+    }
+
+    String nodeManagerPort = String.valueOf(NetworkUtil.getUnusedPort());
+
+    // Jars in the `ray` wheel don't contain test classes, so we add test classes explicitly.
+    // Since mvn test classes contain `test` in their path and bazel test classes are located in
+    // a jar with `test` in its name, we keep only the classpath entries that contain `test`.
+    String classpath = Stream.of(
+        System.getProperty("java.class.path").split(File.pathSeparator))
+        .filter(s -> !s.contains(" ") && s.contains("test"))
+        .collect(Collectors.joining(File.pathSeparator));
+    String workerOptions = new Gson().toJson(ImmutableList.of("-classpath", classpath));
+    Map<String, Object> config = new HashMap<>(RayConfig.create().rayletConfigParameters);
+    config.put("num_workers_per_process_java", "1");
+    // Start ray cluster.
+    List<String> startCommand = ImmutableList.of(
+        "ray",
+        "start",
+        "--head",
+        "--redis-port=" + REDIS_PORT,
+        String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
+        String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
+        String.format("--node-manager-port=%s", nodeManagerPort),
+        "--load-code-from-local",
+        "--include-java",
+        "--java-worker-options=" + workerOptions,
+        "--internal-config=" + new Gson().toJson(config)
+    );
+    if (!executeCommand(startCommand, COMMAND_TIMEOUT_SECONDS)) {
+      throw new RuntimeException("Couldn't start ray cluster.");
+    }
+
+    // Connect to the cluster.
+    System.setProperty("ray.redis.address", "127.0.0.1:" + REDIS_PORT);
+    System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
+    System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
+    System.setProperty("ray.raylet.node-manager-port", nodeManagerPort);
+    Ray.init();
+  }
+
+  /**
+   * Shut down Ray and clear the system properties set by {@code startCluster}.
+   *
+   * @param isCrossLanguage Whether to additionally run {@code ray stop}, matching a cluster
+   *     that was launched with {@code ray start}.
+   * @throws RuntimeException If the {@code ray stop} command fails or times out.
+   */
+  public static synchronized void stopCluster(boolean isCrossLanguage) {
+    // Disconnect from the cluster.
+    Ray.shutdown();
+    System.clearProperty("ray.redis.address");
+    System.clearProperty("ray.object-store.socket-name");
+    System.clearProperty("ray.raylet.socket-name");
+    System.clearProperty("ray.raylet.node-manager-port");
+    System.clearProperty("ray.raylet.config.num_workers_per_process_java");
+    System.clearProperty("ray.run-mode");
+
+    if (isCrossLanguage) {
+      // Stop ray cluster.
+      final List<String> stopCommand = ImmutableList.of("ray", "stop");
+      if (!executeCommand(stopCommand, COMMAND_TIMEOUT_SECONDS)) {
+        throw new RuntimeException("Couldn't stop ray cluster");
+      }
+    }
+  }
+
+  /**
+   * Execute an external command, inheriting this process's stdout and stderr.
+   *
+   * @param command The program followed by its arguments.
+   * @param waitTimeoutSeconds How long to wait for the command to exit before killing it.
+   * @return Whether the command exited with status 0 within the timeout.
+   * @throws RuntimeException If the command can't be started or the wait is interrupted.
+   */
+  private static boolean executeCommand(List<String> command, int waitTimeoutSeconds) {
+    String commandLine = String.join(" ", command);
+    LOG.info("Executing command: {}", commandLine);
+    try {
+      Process process = new ProcessBuilder(command)
+          .redirectOutput(ProcessBuilder.Redirect.INHERIT)
+          .redirectError(ProcessBuilder.Redirect.INHERIT)
+          .start();
+      if (!process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS)) {
+        // destroyForcibly() doesn't wait for the process to die, so calling exitValue() right
+        // after it may throw IllegalThreadStateException. A timeout always counts as failure.
+        LOG.warn("Command timed out after {} seconds: {}", waitTimeoutSeconds, commandLine);
+        process.destroyForcibly();
+        return false;
+      }
+      return process.exitValue() == 0;
+    } catch (InterruptedException e) {
+      // Restore the interrupt flag so code further up the stack can still observe it.
+      Thread.currentThread().interrupt();
+      throw new RuntimeException("Error executing command " + commandLine, e);
+    } catch (Exception e) {
+      throw new RuntimeException("Error executing command " + commandLine, e);
+    }
+  }
+}
diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java
index edf2fcd50..5f1ab4d4e 100644
--- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java
+++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/context/StreamingContext.java
@@ -1,10 +1,12 @@
package io.ray.streaming.api.context;
import com.google.common.base.Preconditions;
+import io.ray.api.Ray;
import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.jobgraph.JobGraphBuilder;
import io.ray.streaming.schedule.JobScheduler;
+import io.ray.streaming.util.Config;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
@@ -13,11 +15,14 @@ import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.concurrent.atomic.AtomicInteger;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
* Encapsulate the context information of a streaming Job.
*/
public class StreamingContext implements Serializable {
+ private static final Logger LOG = LoggerFactory.getLogger(StreamingContext.class);
private transient AtomicInteger idGenerator;
@@ -54,6 +59,20 @@ public class StreamingContext implements Serializable {
this.jobGraph = jobGraphBuilder.build();
jobGraph.printJobGraph();
+ if (Ray.internal() == null) {
+ if (Config.MEMORY_CHANNEL.equalsIgnoreCase(jobConfig.get(Config.CHANNEL_TYPE))) {
+ Preconditions.checkArgument(!jobGraph.isCrossLanguageGraph());
+ ClusterStarter.startCluster(false, true);
+ LOG.info("Created local cluster for job {}.", jobName);
+ } else {
+ ClusterStarter.startCluster(jobGraph.isCrossLanguageGraph(), false);
+ LOG.info("Created multi process cluster for job {}.", jobName);
+ }
+ Runtime.getRuntime().addShutdownHook(new Thread(StreamingContext.this::stop));
+ } else {
+ LOG.info("Reuse existing cluster.");
+ }
+
ServiceLoader serviceLoader = ServiceLoader.load(JobScheduler.class);
Iterator iterator = serviceLoader.iterator();
Preconditions.checkArgument(iterator.hasNext(),
@@ -77,4 +96,11 @@ public class StreamingContext implements Serializable {
public void withConfig(Map jobConfig) {
this.jobConfig = jobConfig;
}
+
+  /** Stops the cluster of this context; does nothing if Ray is down or no job graph exists. */
+  public void stop() {
+    if (Ray.internal() != null && jobGraph != null) {
+      ClusterStarter.stopCluster(jobGraph.isCrossLanguageGraph());
+    }
+  }
}
diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStream.java
index b3b43fd6c..fe0a3af1b 100644
--- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStream.java
+++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStream.java
@@ -1,6 +1,7 @@
package io.ray.streaming.api.stream;
+import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.FilterFunction;
import io.ray.streaming.api.function.impl.FlatMapFunction;
@@ -15,24 +16,44 @@ import io.ray.streaming.operator.impl.FlatMapOperator;
import io.ray.streaming.operator.impl.KeyByOperator;
import io.ray.streaming.operator.impl.MapOperator;
import io.ray.streaming.operator.impl.SinkOperator;
+import io.ray.streaming.python.stream.PythonDataStream;
/**
* Represents a stream of data.
- *
- * This class defines all the streaming operations.
+ * <p>This class defines all the streaming operations.
*
* @param Type of data in the stream.
*/
-public class DataStream extends Stream {
+public class DataStream extends Stream, T> {
public DataStream(StreamingContext streamingContext, StreamOperator streamOperator) {
super(streamingContext, streamOperator);
}
- public DataStream(DataStream input, StreamOperator streamOperator) {
+ public DataStream(StreamingContext streamingContext,
+ StreamOperator streamOperator,
+ Partition partition) {
+ super(streamingContext, streamOperator, partition);
+ }
+
+ public DataStream(DataStream input, StreamOperator streamOperator) {
super(input, streamOperator);
}
+ public DataStream(DataStream input,
+ StreamOperator streamOperator,
+ Partition partition) {
+ super(input, streamOperator, partition);
+ }
+
+ /**
+ * Create a Java stream that references the passed Python stream.
+ * Changes in the new stream will be reflected in the referenced stream and vice versa.
+ */
+ public DataStream(PythonDataStream referencedStream) {
+ super(referencedStream);
+ }
+
/**
* Apply a map function to this stream.
*
@@ -41,7 +62,7 @@ public class DataStream extends Stream {
* @return A new DataStream.
*/
public DataStream map(MapFunction mapFunction) {
- return new DataStream<>(this, new MapOperator(mapFunction));
+ return new DataStream<>(this, new MapOperator<>(mapFunction));
}
/**
@@ -52,11 +73,11 @@ public class DataStream extends Stream {
* @return A new DataStream
*/
public DataStream flatMap(FlatMapFunction flatMapFunction) {
- return new DataStream(this, new FlatMapOperator(flatMapFunction));
+ return new DataStream<>(this, new FlatMapOperator<>(flatMapFunction));
}
public DataStream filter(FilterFunction filterFunction) {
- return new DataStream(this, new FilterOperator(filterFunction));
+ return new DataStream<>(this, new FilterOperator<>(filterFunction));
}
/**
@@ -66,7 +87,7 @@ public class DataStream extends Stream {
* @return A new UnionStream.
*/
public UnionStream union(DataStream other) {
- return new UnionStream(this, null, other);
+ return new UnionStream<>(this, null, other);
}
/**
@@ -93,7 +114,7 @@ public class DataStream extends Stream {
* @return A new StreamSink.
*/
public DataStreamSink sink(SinkFunction sinkFunction) {
- return new DataStreamSink<>(this, new SinkOperator(sinkFunction));
+ return new DataStreamSink<>(this, new SinkOperator<>(sinkFunction));
}
/**
@@ -104,7 +125,8 @@ public class DataStream extends Stream {
* @return A new KeyDataStream.
*/
public KeyDataStream keyBy(KeyFunction keyFunction) {
- return new KeyDataStream<>(this, new KeyByOperator(keyFunction));
+ checkPartitionCall();
+ return new KeyDataStream<>(this, new KeyByOperator<>(keyFunction));
}
/**
@@ -113,8 +135,8 @@ public class DataStream extends Stream {
* @return This stream.
*/
public DataStream broadcast() {
- this.partition = new BroadcastPartition<>();
- return this;
+ checkPartitionCall();
+ return setPartition(new BroadcastPartition<>());
}
/**
@@ -124,19 +146,32 @@ public class DataStream extends Stream {
* @return This stream.
*/
public DataStream partitionBy(Partition partition) {
- this.partition = partition;
- return this;
+ checkPartitionCall();
+ return setPartition(partition);
}
/**
- * Set parallelism to current transformation.
- *
- * @param parallelism The parallelism to set.
- * @return This stream.
+ * If parent stream is a python stream, we can't call partition related methods
+ * in the java stream.
*/
- public DataStream setParallelism(int parallelism) {
- this.parallelism = parallelism;
- return this;
+ private void checkPartitionCall() {
+ if (getInputStream() != null && getInputStream().getLanguage() == Language.PYTHON) {
+ throw new RuntimeException("Partition related methods can't be called on a " +
+ "java stream if parent stream is a python stream.");
+ }
}
+ /**
+ * Convert this stream to a Python stream.
+ * The converted stream and this stream are the same logical stream with the same stream id.
+ * Changes in the converted stream will be reflected in this stream and vice versa.
+ */
+ public PythonDataStream asPythonStream() {
+ return new PythonDataStream(this);
+ }
+
+ @Override
+ public Language getLanguage() {
+ return Language.JAVA;
+ }
}
diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSink.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSink.java
index e19d1b027..e58bb420b 100644
--- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSink.java
+++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSink.java
@@ -1,5 +1,6 @@
package io.ray.streaming.api.stream;
+import io.ray.streaming.api.Language;
import io.ray.streaming.operator.impl.SinkOperator;
/**
@@ -9,13 +10,13 @@ import io.ray.streaming.operator.impl.SinkOperator;
*/
public class DataStreamSink extends StreamSink {
- public DataStreamSink(DataStream input, SinkOperator sinkOperator) {
+ public DataStreamSink(DataStream input, SinkOperator sinkOperator) {
super(input, sinkOperator);
- this.streamingContext.addSink(this);
+ getStreamingContext().addSink(this);
}
- public DataStreamSink setParallelism(int parallelism) {
- this.parallelism = parallelism;
- return this;
+ @Override
+ public Language getLanguage() {
+ return Language.JAVA;
}
}
diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSource.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSource.java
index 9f3b353ca..87ccb5eaf 100644
--- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSource.java
+++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/DataStreamSource.java
@@ -14,27 +14,26 @@ import java.util.Collection;
*/
public class DataStreamSource extends DataStream implements StreamSource {
- public DataStreamSource(StreamingContext streamingContext, SourceFunction sourceFunction) {
- super(streamingContext, new SourceOperator<>(sourceFunction));
- super.partition = new RoundRobinPartition<>();
+ private DataStreamSource(StreamingContext streamingContext, SourceFunction sourceFunction) {
+ super(streamingContext, new SourceOperator<>(sourceFunction), new RoundRobinPartition<>());
+ }
+
+ public static DataStreamSource fromSource(
+ StreamingContext context, SourceFunction sourceFunction) {
+ return new DataStreamSource<>(context, sourceFunction);
}
/**
* Build a DataStreamSource source from a collection.
*
* @param context Stream context.
- * @param values A collection of values.
- * @param The type of source data.
+ * @param values A collection of values.
+ * @param The type of source data.
* @return A DataStreamSource.
*/
- public static DataStreamSource buildSource(
+ public static DataStreamSource fromCollection(
StreamingContext context, Collection values) {
- return new DataStreamSource(context, new CollectionSourceFunction(values));
+ return new DataStreamSource<>(context, new CollectionSourceFunction<>(values));
}
- @Override
- public DataStreamSource setParallelism(int parallelism) {
- this.parallelism = parallelism;
- return this;
- }
}
diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/KeyDataStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/KeyDataStream.java
index ad48f2efa..68708b9e9 100644
--- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/KeyDataStream.java
+++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/KeyDataStream.java
@@ -2,9 +2,12 @@ package io.ray.streaming.api.stream;
import io.ray.streaming.api.function.impl.AggregateFunction;
import io.ray.streaming.api.function.impl.ReduceFunction;
+import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.partition.impl.KeyPartition;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.operator.impl.ReduceOperator;
+import io.ray.streaming.python.stream.PythonDataStream;
+import io.ray.streaming.python.stream.PythonKeyDataStream;
/**
* Represents a DataStream returned by a key-by operation.
@@ -12,11 +15,19 @@ import io.ray.streaming.operator.impl.ReduceOperator;
* @param Type of the key.
* @param Type of the data.
*/
+@SuppressWarnings("unchecked")
public class KeyDataStream extends DataStream {
public KeyDataStream(DataStream input, StreamOperator streamOperator) {
- super(input, streamOperator);
- this.partition = new KeyPartition();
+ super(input, streamOperator, (Partition) new KeyPartition());
+ }
+
+ /**
+ * Create a Java stream that references the passed Python stream.
+ * Changes in the new stream will be reflected in the referenced stream and vice versa.
+ */
+ public KeyDataStream(PythonDataStream referencedStream) {
+ super(referencedStream);
}
/**
@@ -41,8 +52,13 @@ public class KeyDataStream extends DataStream {
return new DataStream<>(this, null);
}
- public KeyDataStream setParallelism(int parallelism) {
- this.parallelism = parallelism;
- return this;
+ /**
+ * Convert this stream to a Python stream.
+ * The converted stream and this stream are the same logical stream with the same stream id.
+ * Changes in the converted stream will be reflected in this stream and vice versa.
+ */
+ public PythonKeyDataStream asPythonStream() {
+ return new PythonKeyDataStream(this);
}
+
}
diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java
index 791124c41..4c74780cd 100644
--- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java
+++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/Stream.java
@@ -1,58 +1,99 @@
package io.ray.streaming.api.stream;
+import com.google.common.base.Preconditions;
+import io.ray.streaming.api.Language;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
+import io.ray.streaming.operator.Operator;
import io.ray.streaming.operator.StreamOperator;
-import io.ray.streaming.python.PythonOperator;
import io.ray.streaming.python.PythonPartition;
-import io.ray.streaming.python.stream.PythonStream;
import java.io.Serializable;
/**
* Abstract base class of all stream types.
*
+ * @param <S> Type of stream class
* @param <T> Type of the data in the stream.
*/
-public abstract class Stream implements Serializable {
- protected int id;
- protected int parallelism = 1;
- protected StreamOperator operator;
- protected Stream inputStream;
- protected StreamingContext streamingContext;
- protected Partition partition;
+public abstract class Stream, T>
+ implements Serializable {
+ private final int id;
+ private final StreamingContext streamingContext;
+ private final Stream inputStream;
+ private final StreamOperator operator;
+ private int parallelism = 1;
+ private Partition partition;
+ private Stream originalStream;
- @SuppressWarnings("unchecked")
public Stream(StreamingContext streamingContext, StreamOperator streamOperator) {
+ this(streamingContext, null, streamOperator,
+ selectPartition(streamOperator));
+ }
+
+ public Stream(StreamingContext streamingContext,
+ StreamOperator streamOperator,
+ Partition partition) {
+ this(streamingContext, null, streamOperator, partition);
+ }
+
+ public Stream(Stream inputStream, StreamOperator streamOperator) {
+ this(inputStream.getStreamingContext(), inputStream, streamOperator,
+ selectPartition(streamOperator));
+ }
+
+ public Stream(Stream inputStream, StreamOperator streamOperator, Partition partition) {
+ this(inputStream.getStreamingContext(), inputStream, streamOperator, partition);
+ }
+
+ protected Stream(StreamingContext streamingContext,
+ Stream inputStream,
+ StreamOperator streamOperator,
+ Partition partition) {
this.streamingContext = streamingContext;
+ this.inputStream = inputStream;
this.operator = streamOperator;
+ this.partition = partition;
this.id = streamingContext.generateId();
- if (streamOperator instanceof PythonOperator) {
- this.partition = PythonPartition.RoundRobinPartition;
- } else {
- this.partition = new RoundRobinPartition<>();
+ if (inputStream != null) {
+ this.parallelism = inputStream.getParallelism();
}
}
- public Stream(Stream inputStream, StreamOperator streamOperator) {
- this.inputStream = inputStream;
- this.parallelism = inputStream.getParallelism();
- this.streamingContext = this.inputStream.getStreamingContext();
- this.operator = streamOperator;
- this.id = streamingContext.generateId();
- this.partition = selectPartition();
+ /**
+ * Create a proxy stream of the original stream.
+ * Changes in the new stream will be reflected in the original stream and vice versa.
+ */
+ protected Stream(Stream originalStream) {
+ this.originalStream = originalStream;
+ this.id = originalStream.getId();
+ this.streamingContext = originalStream.getStreamingContext();
+ this.inputStream = originalStream.getInputStream();
+ this.operator = originalStream.getOperator();
}
@SuppressWarnings("unchecked")
- private Partition selectPartition() {
- if (inputStream instanceof PythonStream) {
- return PythonPartition.RoundRobinPartition;
- } else {
- return new RoundRobinPartition<>();
+ private static Partition selectPartition(Operator operator) {
+ switch (operator.getLanguage()) {
+ case PYTHON:
+ return (Partition) PythonPartition.RoundRobinPartition;
+ case JAVA:
+ return new RoundRobinPartition<>();
+ default:
+ throw new UnsupportedOperationException(
+ "Unsupported language " + operator.getLanguage());
}
}
- public Stream getInputStream() {
+ public int getId() {
+ return id;
+ }
+
+ public StreamingContext getStreamingContext() {
+ return streamingContext;
+ }
+
+ public Stream getInputStream() {
return inputStream;
}
@@ -60,32 +101,47 @@ public abstract class Stream implements Serializable {
return operator;
}
- public void setOperator(StreamOperator operator) {
- this.operator = operator;
- }
-
- public StreamingContext getStreamingContext() {
- return streamingContext;
+ @SuppressWarnings("unchecked")
+ private S self() {
+ return (S) this;
}
public int getParallelism() {
- return parallelism;
+ return originalStream != null ? originalStream.getParallelism() : parallelism;
}
- public Stream setParallelism(int parallelism) {
- this.parallelism = parallelism;
- return this;
- }
-
- public int getId() {
- return id;
+ public S setParallelism(int parallelism) {
+ if (originalStream != null) {
+ originalStream.setParallelism(parallelism);
+ } else {
+ this.parallelism = parallelism;
+ }
+ return self();
}
+ @SuppressWarnings("unchecked")
public Partition getPartition() {
- return partition;
+ return originalStream != null ? originalStream.getPartition() : partition;
}
- public void setPartition(Partition partition) {
- this.partition = partition;
+ @SuppressWarnings("unchecked")
+ protected S setPartition(Partition partition) {
+ if (originalStream != null) {
+ originalStream.setPartition(partition);
+ } else {
+ this.partition = partition;
+ }
+ return self();
}
+
+ public boolean isProxyStream() {
+ return originalStream != null;
+ }
+
+ public Stream getOriginalStream() {
+ Preconditions.checkArgument(isProxyStream());
+ return originalStream;
+ }
+
+ public abstract Language getLanguage();
}
diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/StreamSink.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/StreamSink.java
index 944b93eae..f03b1baa4 100644
--- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/StreamSink.java
+++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/StreamSink.java
@@ -7,8 +7,8 @@ import io.ray.streaming.operator.StreamOperator;
*
* @param Type of the input data of this sink.
*/
-public class StreamSink extends Stream {
- public StreamSink(Stream inputStream, StreamOperator streamOperator) {
+public abstract class StreamSink extends Stream, T> {
+ public StreamSink(Stream inputStream, StreamOperator streamOperator) {
super(inputStream, streamOperator);
}
}
diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java
index ed7434c5c..6dd559ce7 100644
--- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java
+++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/stream/UnionStream.java
@@ -11,15 +11,15 @@ import java.util.List;
*/
public class UnionStream extends DataStream {
- private List unionStreams;
+ private List> unionStreams;
- public UnionStream(DataStream input, StreamOperator streamOperator, DataStream other) {
+ public UnionStream(DataStream input, StreamOperator streamOperator, DataStream other) {
super(input, streamOperator);
this.unionStreams = new ArrayList<>();
this.unionStreams.add(other);
}
- public List getUnionStreams() {
+ public List> getUnionStreams() {
return unionStreams;
}
}
diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraph.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraph.java
index 675cad1ea..e670e5ea3 100644
--- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraph.java
+++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraph.java
@@ -1,5 +1,6 @@
package io.ray.streaming.jobgraph;
+import io.ray.streaming.api.Language;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
@@ -97,4 +98,14 @@ public class JobGraph implements Serializable {
}
}
+  /**
+   * Tells whether the vertices of this job graph use more than one language.
+   *
+   * @return {@code true} if at least one vertex reports a language different from the first
+   *     vertex; {@code false} otherwise, including when the graph has no vertices.
+   */
+  public boolean isCrossLanguageGraph() {
+    // An empty graph has no first vertex to compare against; it is trivially single-language.
+    if (jobVertexList.isEmpty()) {
+      return false;
+    }
+    Language language = jobVertexList.get(0).getLanguage();
+    for (JobVertex jobVertex : jobVertexList) {
+      if (jobVertex.getLanguage() != language) {
+        return true;
+      }
+    }
+    return false;
+  }
+
}
diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java
index b8d5af9a4..d0f6a7dc3 100644
--- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java
+++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/jobgraph/JobGraphBuilder.java
@@ -1,5 +1,6 @@
package io.ray.streaming.jobgraph;
+import com.google.common.base.Preconditions;
import io.ray.streaming.api.stream.DataStream;
import io.ray.streaming.api.stream.Stream;
import io.ray.streaming.api.stream.StreamSink;
@@ -10,8 +11,11 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
public class JobGraphBuilder {
+ private static final Logger LOG = LoggerFactory.getLogger(JobGraphBuilder.class);
private JobGraph jobGraph;
@@ -41,12 +45,19 @@ public class JobGraphBuilder {
}
private void processStream(Stream stream) {
+ while (stream.isProxyStream()) {
+ // Proxy stream and original stream are the same logical stream, both refer to the
+ // same data flow transformation. We should skip proxy stream to avoid applying same
+ // transformation multiple times.
+ LOG.debug("Skip proxy stream {} of id {}", stream, stream.getId());
+ stream = stream.getOriginalStream();
+ }
+ StreamOperator streamOperator = stream.getOperator();
+ Preconditions.checkArgument(stream.getLanguage() == streamOperator.getLanguage(),
+ "Reference stream should be skipped.");
int vertexId = stream.getId();
int parallelism = stream.getParallelism();
-
- StreamOperator streamOperator = stream.getOperator();
- JobVertex jobVertex = null;
-
+ JobVertex jobVertex;
if (stream instanceof StreamSink) {
jobVertex = new JobVertex(vertexId, parallelism, VertexType.SINK, streamOperator);
Stream parentStream = stream.getInputStream();
diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java
index d91b4cbd5..c99ec9959 100644
--- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java
+++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java
@@ -1,6 +1,8 @@
package io.ray.streaming.message;
+import java.util.Objects;
+
public class KeyRecord extends Record {
private K key;
@@ -17,4 +19,24 @@ public class KeyRecord extends Record {
public void setKey(K key) {
this.key = key;
}
+
+  /**
+   * Two key records are equal when they have the same runtime class, {@code super.equals}
+   * accepts the other object, and their keys are equal.
+   */
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    if (!super.equals(o)) {
+      return false;
+    }
+    // Wildcards: the type arguments of `o` are unknown here and irrelevant to the comparison.
+    KeyRecord<?, ?> keyRecord = (KeyRecord<?, ?>) o;
+    return Objects.equals(key, keyRecord.key);
+  }
+
+  /** Mixes the superclass hash with the key, mirroring the fields compared by equals. */
+  @Override
+  public int hashCode() {
+    return Objects.hash(super.hashCode(), key);
+  }
}
diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Message.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Message.java
deleted file mode 100644
index a943dcb9d..000000000
--- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Message.java
+++ /dev/null
@@ -1,64 +0,0 @@
-package io.ray.streaming.message;
-
-import com.google.common.collect.Lists;
-import java.io.Serializable;
-import java.util.List;
-
-public class Message implements Serializable {
-
- private int taskId;
- private long batchId;
- private String stream;
- private List recordList;
-
- public Message(int taskId, long batchId, String stream, List recordList) {
- this.taskId = taskId;
- this.batchId = batchId;
- this.stream = stream;
- this.recordList = recordList;
- }
-
- public Message(int taskId, long batchId, String stream, Record record) {
- this.taskId = taskId;
- this.batchId = batchId;
- this.stream = stream;
- this.recordList = Lists.newArrayList(record);
- }
-
- public int getTaskId() {
- return taskId;
- }
-
- public void setTaskId(int taskId) {
- this.taskId = taskId;
- }
-
- public long getBatchId() {
- return batchId;
- }
-
- public void setBatchId(long batchId) {
- this.batchId = batchId;
- }
-
- public String getStream() {
- return stream;
- }
-
- public void setStream(String stream) {
- this.stream = stream;
- }
-
- public List getRecordList() {
- return recordList;
- }
-
- public void setRecordList(List recordList) {
- this.recordList = recordList;
- }
-
- public Record getRecord(int index) {
- return recordList.get(0);
- }
-
-}
diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Record.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Record.java
index 8d0ca368b..c86e47645 100644
--- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Record.java
+++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/Record.java
@@ -1,6 +1,7 @@
package io.ray.streaming.message;
import java.io.Serializable;
+import java.util.Objects;
public class Record implements Serializable {
@@ -27,6 +28,24 @@ public class Record implements Serializable {
this.stream = stream;
}
+  /**
+   * Two records are equal when they have the same runtime class and both their stream and
+   * their value are equal (null-safe).
+   */
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    // Wildcard: the value type of `o` is unknown here and irrelevant to the comparison.
+    Record<?> record = (Record<?>) o;
+    return Objects.equals(stream, record.stream) &&
+        Objects.equals(value, record.value);
+  }
+
+  /** Hashes exactly the fields compared by equals: stream and value. */
+  @Override
+  public int hashCode() {
+    return Objects.hash(stream, value);
+  }
+
@Override
public String toString() {
return value.toString();
diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java
index 31c82b43e..21533de70 100644
--- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java
+++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonFunction.java
@@ -1,6 +1,8 @@
package io.ray.streaming.python;
+import com.google.common.base.Preconditions;
import io.ray.streaming.api.function.Function;
+import org.apache.commons.lang3.StringUtils;
/**
* Represents a user defined python function.
@@ -14,9 +16,8 @@ import io.ray.streaming.api.function.Function;
*
*
If the python data stream api is invoked from python, `function` will be not null.
*
If the python data stream api is invoked from java, `moduleName` and
- * `className`/`functionName` will be not null.
+ * `functionName` will be not null.
*
- * TODO serialize to bytes using protobuf
*/
public class PythonFunction implements Function {
public enum FunctionInterface {
@@ -38,23 +39,43 @@ public class PythonFunction implements Function {
}
}
- private byte[] function;
- private String moduleName;
- private String className;
- private String functionName;
+ // null if this function is constructed from moduleName/functionName.
+ private final byte[] function;
+ // null if this function is constructed from serialized python function.
+ private final String moduleName;
+ // null if this function is constructed from serialized python function.
+ private final String functionName;
/**
* FunctionInterface can be used to validate python function,
* and look up operator class from FunctionInterface.
*/
private String functionInterface;
- private PythonFunction(byte[] function,
- String moduleName,
- String className,
- String functionName) {
+ /**
+ * Create a {@link PythonFunction} from a serialized streaming python function.
+ *
+ * @param function serialized streaming python function from python driver.
+ */
+ public PythonFunction(byte[] function) {
+ Preconditions.checkNotNull(function);
this.function = function;
+ this.moduleName = null;
+ this.functionName = null;
+ }
+
+ /**
+ * Create a {@link PythonFunction} from a moduleName and streaming function name.
+ *
+ * @param moduleName module name of streaming function.
+ * @param functionName function name of streaming function. {@code functionName} is the name
+ * of a python function, or class name of subclass of `ray.streaming.function.`
+ */
+ public PythonFunction(String moduleName,
+ String functionName) {
+ Preconditions.checkArgument(StringUtils.isNotBlank(moduleName));
+ Preconditions.checkArgument(StringUtils.isNotBlank(functionName));
+ this.function = null;
this.moduleName = moduleName;
- this.className = className;
this.functionName = functionName;
}
@@ -70,10 +91,6 @@ public class PythonFunction implements Function {
return moduleName;
}
- public String getClassName() {
- return className;
- }
-
public String getFunctionName() {
return functionName;
}
@@ -82,34 +99,4 @@ public class PythonFunction implements Function {
return functionInterface;
}
- /**
- * Create a {@link PythonFunction} using python serialized function
- *
- * @param function serialized python function sent from python driver
- */
- public static PythonFunction fromFunction(byte[] function) {
- return new PythonFunction(function, null, null, null);
- }
-
- /**
- * Create a {@link PythonFunction} using moduleName and
- * className.
- *
- * @param moduleName python module name
- * @param className python class name
- */
- public static PythonFunction fromClassName(String moduleName, String className) {
- return new PythonFunction(null, moduleName, className, null);
- }
-
- /**
- * Create a {@link PythonFunction} using moduleName and
- * functionName.
- *
- * @param moduleName python module name
- * @param functionName python function name
- */
- public static PythonFunction fromFunctionName(String moduleName, String functionName) {
- return new PythonFunction(null, moduleName, null, functionName);
- }
}
diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonPartition.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonPartition.java
index c4548031b..6d8de051f 100644
--- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonPartition.java
+++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/python/PythonPartition.java
@@ -1,6 +1,8 @@
package io.ray.streaming.python;
+import com.google.common.base.Preconditions;
import io.ray.streaming.api.partition.Partition;
+import org.apache.commons.lang3.StringUtils;
/**
* Represents a python partition function.
@@ -13,28 +15,33 @@ import io.ray.streaming.api.partition.Partition;
* If this object is constructed from moduleName and className/functionName,
* python worker will use `importlib` to load python partition function.
*
- * TODO serialize to bytes using protobuf
*/
-public class PythonPartition implements Partition {
+public class PythonPartition implements Partition
+
+ com.google.code.findbugs
+ jsr305
+ 3.0.2
+com.google.guavaguava
@@ -56,6 +61,11 @@
owner1.0.10
+
+ org.apache.commons
+ commons-lang3
+ 3.4
+org.mockitomockito-all
@@ -71,11 +81,6 @@
powermock-api-mockito1.6.6
-
- org.powermock
- powermock-core
- 1.6.6
-org.powermockpowermock-module-testng
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java
index 7939e69b1..566dd15b8 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java
@@ -1,9 +1,14 @@
package io.ray.streaming.runtime.core.collector;
-import io.ray.runtime.serializer.Serializer;
+import io.ray.api.BaseActor;
+import io.ray.api.RayPyActor;
+import io.ray.streaming.api.Language;
import io.ray.streaming.api.collector.Collector;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.message.Record;
+import io.ray.streaming.runtime.serialization.CrossLangSerializer;
+import io.ray.streaming.runtime.serialization.JavaSerializer;
+import io.ray.streaming.runtime.serialization.Serializer;
import io.ray.streaming.runtime.transfer.ChannelID;
import io.ray.streaming.runtime.transfer.DataWriter;
import java.nio.ByteBuffer;
@@ -14,15 +19,24 @@ import org.slf4j.LoggerFactory;
public class OutputCollector implements Collector {
private static final Logger LOGGER = LoggerFactory.getLogger(OutputCollector.class);
- private Partition partition;
- private DataWriter writer;
- private ChannelID[] outputQueues;
+ private final DataWriter writer;
+ private final ChannelID[] outputQueues;
+ private final Collection targetActors;
+ private final Language[] targetLanguages;
+ private final Partition partition;
+ private final Serializer javaSerializer = new JavaSerializer();
+ private final Serializer crossLangSerializer = new CrossLangSerializer();
- public OutputCollector(Collection outputQueueIds,
- DataWriter writer,
+ public OutputCollector(DataWriter writer,
+ Collection outputQueueIds,
+ Collection targetActors,
Partition partition) {
- this.outputQueues = outputQueueIds.stream().map(ChannelID::from).toArray(ChannelID[]::new);
this.writer = writer;
+ this.outputQueues = outputQueueIds.stream().map(ChannelID::from).toArray(ChannelID[]::new);
+ this.targetActors = targetActors;
+ this.targetLanguages = targetActors.stream()
+ .map(actor -> actor instanceof RayPyActor ? Language.PYTHON : Language.JAVA)
+ .toArray(Language[]::new);
this.partition = partition;
LOGGER.debug("OutputCollector constructed, outputQueueIds:{}, partition:{}.",
outputQueueIds, this.partition);
@@ -31,9 +45,32 @@ public class OutputCollector implements Collector {
@Override
public void collect(Record record) {
int[] partitions = this.partition.partition(record, outputQueues.length);
- ByteBuffer msgBuffer = ByteBuffer.wrap(Serializer.encode(record).getLeft());
+ ByteBuffer javaBuffer = null;
+ ByteBuffer crossLangBuffer = null;
for (int partition : partitions) {
- writer.write(outputQueues[partition], msgBuffer);
+ if (targetLanguages[partition] == Language.JAVA) {
+ // avoid repeated serialization
+ if (javaBuffer == null) {
+ byte[] bytes = javaSerializer.serialize(record);
+ javaBuffer = ByteBuffer.allocate(1 + bytes.length);
+ javaBuffer.put(Serializer.JAVA_TYPE_ID);
+ // TODO(chaokunyang) remove copy
+ javaBuffer.put(bytes);
+ javaBuffer.flip();
+ }
+ writer.write(outputQueues[partition], javaBuffer.duplicate());
+ } else {
+ // avoid repeated serialization
+ if (crossLangBuffer == null) {
+ byte[] bytes = crossLangSerializer.serialize(record);
+ crossLangBuffer = ByteBuffer.allocate(1 + bytes.length);
+ crossLangBuffer.put(Serializer.CROSS_LANG_TYPE_ID);
+ // TODO(chaokunyang) remove copy
+ crossLangBuffer.put(bytes);
+ crossLangBuffer.flip();
+ }
+ writer.write(outputQueues[partition], crossLangBuffer.duplicate());
+ }
}
}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java
index d90e02463..b79926490 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java
@@ -12,6 +12,7 @@ import io.ray.streaming.runtime.core.graph.ExecutionNode;
import io.ray.streaming.runtime.core.graph.ExecutionTask;
import io.ray.streaming.runtime.generated.RemoteCall;
import io.ray.streaming.runtime.generated.Streaming;
+import io.ray.streaming.runtime.serialization.MsgPackSerializer;
import java.util.Arrays;
public class GraphPbBuilder {
@@ -74,11 +75,10 @@ public class GraphPbBuilder {
private byte[] serializeFunction(Function function) {
if (function instanceof PythonFunction) {
PythonFunction pyFunc = (PythonFunction) function;
- // function_bytes, module_name, class_name, function_name, function_interface
+ // function_bytes, module_name, function_name, function_interface
return serializer.serialize(Arrays.asList(
pyFunc.getFunction(), pyFunc.getModuleName(),
- pyFunc.getClassName(), pyFunc.getFunctionName(),
- pyFunc.getFunctionInterface()
+ pyFunc.getFunctionName(), pyFunc.getFunctionInterface()
));
} else {
return new byte[0];
@@ -88,10 +88,10 @@ public class GraphPbBuilder {
private byte[] serializePartition(Partition partition) {
if (partition instanceof PythonPartition) {
PythonPartition pythonPartition = (PythonPartition) partition;
- // partition_bytes, module_name, class_name, function_name
+ // partition_bytes, module_name, function_name
return serializer.serialize(Arrays.asList(
pythonPartition.getPartition(), pythonPartition.getModuleName(),
- pythonPartition.getClassName(), pythonPartition.getFunctionName()
+ pythonPartition.getFunctionName()
));
} else {
return new byte[0];
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java
index 80a487db0..826f1c935 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/PythonGateway.java
@@ -1,16 +1,21 @@
package io.ray.streaming.runtime.python;
+import com.google.common.base.Preconditions;
+import com.google.common.primitives.Primitives;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.python.PythonFunction;
import io.ray.streaming.python.PythonPartition;
import io.ray.streaming.python.stream.PythonStreamSource;
+import io.ray.streaming.runtime.serialization.MsgPackSerializer;
import io.ray.streaming.runtime.util.ReflectionUtils;
import java.lang.reflect.Method;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
+import java.util.function.Function;
import java.util.stream.Collectors;
-import org.msgpack.core.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -68,7 +73,7 @@ public class PythonGateway {
Preconditions.checkNotNull(streamingContext);
try {
PythonStreamSource pythonStreamSource = PythonStreamSource.from(
- streamingContext, PythonFunction.fromFunction(pySourceFunc));
+ streamingContext, new PythonFunction(pySourceFunc));
referenceMap.put(getReferenceId(pythonStreamSource), pythonStreamSource);
return serializer.serialize(getReferenceId(pythonStreamSource));
} catch (Exception e) {
@@ -84,7 +89,7 @@ public class PythonGateway {
}
public byte[] createPyFunc(byte[] pyFunc) {
- PythonFunction function = PythonFunction.fromFunction(pyFunc);
+ PythonFunction function = new PythonFunction(pyFunc);
referenceMap.put(getReferenceId(function), function);
return serializer.serialize(getReferenceId(function));
}
@@ -98,15 +103,21 @@ public class PythonGateway {
public byte[] callFunction(byte[] paramsBytes) {
try {
List params = (List) serializer.deserialize(paramsBytes);
- params = processReferenceParameters(params);
+ params = processParameters(params);
LOG.info("callFunction params {}", params);
String className = (String) params.get(0);
String funcName = (String) params.get(1);
Class> clz = Class.forName(className, true, this.getClass().getClassLoader());
- Method method = ReflectionUtils.findMethod(clz, funcName);
+ Class[] paramsTypes = params.subList(2, params.size()).stream()
+ .map(Object::getClass).toArray(Class[]::new);
+ Method method = findMethod(clz, funcName, paramsTypes);
Object result = method.invoke(null, params.subList(2, params.size()).toArray());
- referenceMap.put(getReferenceId(result), result);
- return serializer.serialize(getReferenceId(result));
+ if (returnReference(result)) {
+ referenceMap.put(getReferenceId(result), result);
+ return serializer.serialize(getReferenceId(result));
+ } else {
+ return serializer.serialize(result);
+ }
} catch (Exception e) {
throw new RuntimeException(e);
}
@@ -115,31 +126,78 @@ public class PythonGateway {
public byte[] callMethod(byte[] paramsBytes) {
try {
List params = (List) serializer.deserialize(paramsBytes);
- params = processReferenceParameters(params);
+ params = processParameters(params);
LOG.info("callMethod params {}", params);
Object obj = params.get(0);
String methodName = (String) params.get(1);
- Method method = ReflectionUtils.findMethod(obj.getClass(), methodName);
+ Class> clz = obj.getClass();
+ Class[] paramsTypes = params.subList(2, params.size()).stream()
+ .map(Object::getClass).toArray(Class[]::new);
+ Method method = findMethod(clz, methodName, paramsTypes);
Object result = method.invoke(obj, params.subList(2, params.size()).toArray());
- referenceMap.put(getReferenceId(result), result);
- return serializer.serialize(getReferenceId(result));
+ if (returnReference(result)) {
+ referenceMap.put(getReferenceId(result), result);
+ return serializer.serialize(getReferenceId(result));
+ } else {
+ return serializer.serialize(result);
+ }
} catch (Exception e) {
throw new RuntimeException(e);
}
}
- private List processReferenceParameters(List params) {
- return params.stream().map(this::processReferenceParameter)
+  /**
+   * Finds the method named {@code methodName} on {@code cls} that accepts arguments of the
+   * given runtime types.
+   *
+   * <p>If exactly one method has that name it is returned without looking at parameter
+   * types. Otherwise the overloads are filtered by parameter types, treating a primitive type
+   * and its wrapper as the same type. Matching is exact apart from boxing: an argument whose
+   * class is only a subtype of the declared parameter type does not match.
+   *
+   * @param cls class to search.
+   * @param methodName name of the method to find.
+   * @param paramsTypes runtime classes of the arguments that will be passed.
+   * @return a matching method.
+   * @throws IllegalArgumentException if the name is overloaded (or absent) and no candidate
+   *     matches {@code paramsTypes}.
+   */
+  private static Method findMethod(Class<?> cls, String methodName, Class[] paramsTypes) {
+    List<Method> methods = ReflectionUtils.findMethods(cls, methodName);
+    if (methods.size() == 1) {
+      return methods.get(0);
+    }
+    Optional<Method> any = methods.stream()
+        .filter(m -> parameterTypesMatch(m.getParameterTypes(), paramsTypes))
+        .findAny();
+    // Build the message only on failure.
+    return any.orElseThrow(() -> new IllegalArgumentException(
+        String.format("Method %s with type %s doesn't exist on class %s",
+            methodName, Arrays.toString(paramsTypes), cls)));
+  }
+
+  /**
+   * Compares two parameter type lists element by element after boxing both sides, so that
+   * signatures mixing primitive and wrapper parameters, e.g. {@code (int, Integer)}, match too.
+   */
+  private static boolean parameterTypesMatch(Class<?>[] declared, Class<?>[] actual) {
+    if (declared.length != actual.length) {
+      return false;
+    }
+    for (int i = 0; i < declared.length; i++) {
+      Class<?> declaredType = Primitives.wrap(declared[i]);
+      Class<?> actualType = Primitives.wrap(actual[i]);
+      if (!declaredType.equals(actualType)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /**
+   * Decides whether a call result is handed back as a reference id or serialized directly.
+   *
+   * @param value result of the reflective call; may be null.
+   * @return {@code false} for numbers, strings and byte arrays, {@code true} for everything
+   *     else, including {@code null}.
+   */
+  private static boolean returnReference(Object value) {
+    boolean sentByValue =
+        value instanceof Number || value instanceof String || value instanceof byte[];
+    return !sentByValue;
+  }
+
+  /**
+   * Creates an instance of the named class through its no-arg constructor, stores it in
+   * {@code referenceMap} and returns the serialized reference id of the new object.
+   *
+   * @param classNameBytes serialized name of the class to instantiate.
+   * @return serialized reference id of the created instance.
+   * @throws IllegalArgumentException if the class can't be loaded, has no accessible no-arg
+   *     constructor, or its constructor throws.
+   */
+  public byte[] newInstance(byte[] classNameBytes) {
+    String className = (String) serializer.deserialize(classNameBytes);
+    try {
+      Class<?> clz = Class.forName(className, true, this.getClass().getClassLoader());
+      // Class#newInstance is deprecated because it rethrows checked constructor exceptions
+      // undeclared; going through the Constructor wraps them in InvocationTargetException.
+      Object instance = clz.getDeclaredConstructor().newInstance();
+      referenceMap.put(getReferenceId(instance), instance);
+      return serializer.serialize(getReferenceId(instance));
+    } catch (ReflectiveOperationException e) {
+      throw new IllegalArgumentException(
+          String.format("Create instance for class %s failed", className), e);
+    }
+  }
+
+ private List processParameters(List params) {
+ return params.stream().map(this::processParameter)
.collect(Collectors.toList());
}
- private Object processReferenceParameter(Object o) {
+ private Object processParameter(Object o) {
if (o instanceof String) {
Object value = referenceMap.get(o);
if (value != null) {
return value;
}
}
+ // Since python can't represent byte/short, we convert all Byte/Short to Integer
+ if (o instanceof Byte || o instanceof Short) {
+ return ((Number) o).intValue();
+ }
return o;
}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/JobSchedulerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/JobSchedulerImpl.java
index f1de23c8c..deaaf74b3 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/JobSchedulerImpl.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/JobSchedulerImpl.java
@@ -41,15 +41,11 @@ public class JobSchedulerImpl implements JobScheduler {
public void schedule(JobGraph jobGraph, Map jobConfig) {
this.jobConfig = jobConfig;
this.jobGraph = jobGraph;
- if (Ray.internal() == null) {
- System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
- Ray.init();
- }
ExecutionGraph executionGraph = this.taskAssigner.assign(this.jobGraph);
List executionNodes = executionGraph.getExecutionNodeList();
boolean hasPythonNode = executionNodes.stream()
- .allMatch(node -> node.getLanguage() == Language.PYTHON);
+ .anyMatch(node -> node.getLanguage() == Language.PYTHON);
RemoteCall.ExecutionGraph executionGraphPb = null;
if (hasPythonNode) {
executionGraphPb = new GraphPbBuilder().buildExecutionGraphPb(executionGraph);
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/TaskAssignerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/TaskAssignerImpl.java
index 171375ed2..04520b441 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/TaskAssignerImpl.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/schedule/TaskAssignerImpl.java
@@ -2,6 +2,8 @@ package io.ray.streaming.runtime.schedule;
import io.ray.api.BaseActor;
import io.ray.api.Ray;
+import io.ray.api.RayActor;
+import io.ray.api.RayPyActor;
import io.ray.api.function.PyActorClass;
import io.ray.streaming.jobgraph.JobEdge;
import io.ray.streaming.jobgraph.JobGraph;
@@ -15,8 +17,11 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
public class TaskAssignerImpl implements TaskAssigner {
+ private static final Logger LOG = LoggerFactory.getLogger(TaskAssignerImpl.class);
/**
* Assign an optimized logical plan to execution graph.
@@ -61,11 +66,17 @@ public class TaskAssignerImpl implements TaskAssigner {
private BaseActor createWorker(JobVertex jobVertex) {
switch (jobVertex.getLanguage()) {
- case PYTHON:
- return Ray.createActor(
+ case PYTHON: {
+ RayPyActor worker = Ray.createActor(
new PyActorClass("ray.streaming.runtime.worker", "JobWorker"));
- case JAVA:
- return Ray.createActor(JobWorker::new);
+ LOG.info("Created python worker {}", worker);
+ return worker;
+ }
+ case JAVA: {
+ RayActor worker = Ray.createActor(JobWorker::new);
+ LOG.info("Created java worker {}", worker);
+ return worker;
+ }
default:
throw new UnsupportedOperationException(
"Unsupported language " + jobVertex.getLanguage());
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/CrossLangSerializer.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/CrossLangSerializer.java
new file mode 100644
index 000000000..17557b9ac
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/CrossLangSerializer.java
@@ -0,0 +1,62 @@
+package io.ray.streaming.runtime.serialization;
+
+import io.ray.streaming.message.KeyRecord;
+import io.ray.streaming.message.Record;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * A serializer for cross-lang serialization between java/python.
+ *
+ * <p>Records are encoded with MsgPack as a flat list whose first element is a type id:
+ * {@code [RECORD_TYPE_ID, stream, value]} for a {@link Record} and
+ * {@code [KEY_RECORD_TYPE_ID, stream, key, value]} for a {@link KeyRecord}.
+ *
+ * <p>TODO implements a more sophisticated serialization framework
+ */
+public class CrossLangSerializer implements Serializer {
+  private static final byte RECORD_TYPE_ID = 0;
+  private static final byte KEY_RECORD_TYPE_ID = 1;
+
+  private final MsgPackSerializer msgPackSerializer = new MsgPackSerializer();
+
+  /**
+   * Serializes a {@link Record} or {@link KeyRecord}.
+   *
+   * @param object the record to serialize; must be exactly of class {@code Record} or
+   *     {@code KeyRecord} (subclasses are rejected).
+   * @return the MsgPack-encoded bytes.
+   * @throws UnsupportedOperationException if the record is of any other class.
+   */
+  @Override
+  public byte[] serialize(Object object) {
+    Record record = (Record) object;
+    Object value = record.getValue();
+    Class<? extends Record> clz = record.getClass();
+    if (clz == Record.class) {
+      return msgPackSerializer.serialize(Arrays.asList(
+          RECORD_TYPE_ID, record.getStream(), value));
+    } else if (clz == KeyRecord.class) {
+      KeyRecord keyRecord = (KeyRecord) record;
+      Object key = keyRecord.getKey();
+      return msgPackSerializer.serialize(Arrays.asList(
+          KEY_RECORD_TYPE_ID, keyRecord.getStream(), key, value));
+    } else {
+      throw new UnsupportedOperationException(
+          String.format("Serialize %s is unsupported.", record));
+    }
+  }
+
+  /**
+   * Rebuilds a {@link Record} or {@link KeyRecord} from bytes produced by {@link #serialize}.
+   *
+   * @param bytes MsgPack-encoded list in one of the layouts described on this class.
+   * @return the decoded record with its stream set.
+   * @throws UnsupportedOperationException if the leading type id is unknown.
+   */
+  @Override
+  @SuppressWarnings("unchecked")
+  public Record deserialize(byte[] bytes) {
+    List<?> list = (List<?>) msgPackSerializer.deserialize(bytes);
+    // Read the type id through Number instead of casting to Byte, so decoding does not
+    // depend on which boxed integer width the MsgPack layer returns for a small value.
+    byte typeId = ((Number) list.get(0)).byteValue();
+    switch (typeId) {
+      case RECORD_TYPE_ID: {
+        String stream = (String) list.get(1);
+        Object value = list.get(2);
+        Record record = new Record(value);
+        record.setStream(stream);
+        return record;
+      }
+      case KEY_RECORD_TYPE_ID: {
+        String stream = (String) list.get(1);
+        Object key = list.get(2);
+        Object value = list.get(3);
+        KeyRecord keyRecord = new KeyRecord(key, value);
+        keyRecord.setStream(stream);
+        return keyRecord;
+      }
+      default:
+        throw new UnsupportedOperationException("Unsupported type " + typeId);
+    }
+  }
+
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/JavaSerializer.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/JavaSerializer.java
new file mode 100644
index 000000000..d7a1a2649
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/JavaSerializer.java
@@ -0,0 +1,15 @@
+package io.ray.streaming.runtime.serialization;
+
+import io.ray.runtime.serializer.FstSerializer;
+
+/** A {@link Serializer} that delegates encoding and decoding to {@link FstSerializer}. */
+public class JavaSerializer implements Serializer {
+  /** Encodes {@code object} with {@code FstSerializer.encode}. */
+  @Override
+  public byte[] serialize(Object object) {
+    return FstSerializer.encode(object);
+  }
+
+  /**
+   * Decodes {@code bytes} with {@code FstSerializer.decode}.
+   *
+   * @param <T> type the caller expects; the result is not checked against it here.
+   */
+  @Override
+  public <T> T deserialize(byte[] bytes) {
+    return FstSerializer.decode(bytes);
+  }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/MsgPackSerializer.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/MsgPackSerializer.java
similarity index 90%
rename from streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/MsgPackSerializer.java
rename to streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/MsgPackSerializer.java
index 20415a438..2fc9a2c37 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/MsgPackSerializer.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/MsgPackSerializer.java
@@ -1,4 +1,4 @@
-package io.ray.streaming.runtime.python;
+package io.ray.streaming.runtime.serialization;
import com.google.common.io.BaseEncoding;
import java.util.ArrayList;
@@ -31,6 +31,10 @@ public class MsgPackSerializer {
Class> clz = obj.getClass();
if (clz == Boolean.class) {
packer.packBoolean((Boolean) obj);
+ } else if (clz == Byte.class) {
+ packer.packByte((Byte) obj);
+ } else if (clz == Short.class) {
+ packer.packShort((Short) obj);
} else if (clz == Integer.class) {
packer.packInt((Integer) obj);
} else if (clz == Long.class) {
@@ -84,7 +88,11 @@ public class MsgPackSerializer {
return value.asBooleanValue().getBoolean();
case INTEGER:
IntegerValue iv = value.asIntegerValue();
- if (iv.isInIntRange()) {
+ if (iv.isInByteRange()) {
+ return iv.toByte();
+ } else if (iv.isInShortRange()) {
+ return iv.toShort();
+ } else if (iv.isInIntRange()) {
return iv.toInt();
} else if (iv.isInLongRange()) {
return iv.toLong();
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/Serializer.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/Serializer.java
new file mode 100644
index 000000000..b3a3184d7
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/serialization/Serializer.java
@@ -0,0 +1,12 @@
+package io.ray.streaming.runtime.serialization;
+
+/** Converts objects to bytes and back. */
+public interface Serializer {
+  // Ids naming the serializer that produced a payload. OutputCollector writes
+  // CROSS_LANG_TYPE_ID or JAVA_TYPE_ID as the first byte in front of the serialized record.
+  // NOTE(review): PYTHON_TYPE_ID is presumably written by the python side; confirm.
+  byte CROSS_LANG_TYPE_ID = 0;
+  byte JAVA_TYPE_ID = 1;
+  byte PYTHON_TYPE_ID = 2;
+
+  /**
+   * Serializes {@code object} to bytes.
+   *
+   * @param object the object to serialize.
+   * @return the serialized form.
+   */
+  byte[] serialize(Object object);
+
+  /**
+   * Deserializes an object from bytes produced by {@link #serialize}.
+   *
+   * @param bytes the serialized form.
+   * @param <T> type the caller expects the result to have.
+   * @return the deserialized object.
+   */
+  <T> T deserialize(byte[] bytes);
+
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder.java
index b152ca3b7..8506560d9 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder.java
@@ -20,7 +20,7 @@ import java.util.Map;
*/
public class ChannelCreationParametersBuilder {
- public class Parameter {
+ public static class Parameter {
private ActorId actorId;
private FunctionDescriptor asyncFunctionDescriptor;
@@ -138,7 +138,7 @@ public class ChannelCreationParametersBuilder {
parameter.setAsyncFunctionDescriptor(pyAsyncFunctionDesc);
parameter.setSyncFunctionDescriptor(pySyncFunctionDesc);
} else {
- Preconditions.checkArgument(false, "Invalid actor type");
+ throw new IllegalArgumentException("Invalid actor type");
}
parameters.add(parameter);
}
@@ -152,10 +152,10 @@ public class ChannelCreationParametersBuilder {
}
public String toString() {
- String str = "";
+ StringBuilder str = new StringBuilder();
for (Parameter param : parameters) {
- str += param.toString();
+ str.append(param.toString());
}
- return str;
+ return str.toString();
}
}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java
index 64e17f59c..b69396b43 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java
@@ -40,7 +40,7 @@ public class DataReader {
}
long timerInterval = Long.parseLong(
conf.getOrDefault(Config.TIMER_INTERVAL_MS, "-1"));
- String channelType = conf.getOrDefault(Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
+ String channelType = conf.get(Config.CHANNEL_TYPE);
boolean isMock = false;
if (Config.MEMORY_CHANNEL.equals(channelType)) {
isMock = true;
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java
index 25e02940e..39678aebb 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java
@@ -37,7 +37,7 @@ public class DataWriter {
Map conf) {
Preconditions.checkArgument(!outputChannels.isEmpty());
Preconditions.checkArgument(outputChannels.size() == toActors.size());
- ChannelCreationParametersBuilder initialParameters =
+ ChannelCreationParametersBuilder initParameters =
new ChannelCreationParametersBuilder().buildOutputQueueParameters(outputChannels, toActors);
byte[][] outputChannelsBytes = outputChannels.stream()
.map(ChannelID::idStrToBytes).toArray(byte[][]::new);
@@ -47,13 +47,14 @@ public class DataWriter {
for (int i = 0; i < outputChannels.size(); i++) {
msgIds[i] = 0;
}
- String channelType = conf.getOrDefault(Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
+ String channelType = conf.get(Config.CHANNEL_TYPE);
boolean isMock = false;
- if (Config.MEMORY_CHANNEL.equals(channelType)) {
+ if (Config.MEMORY_CHANNEL.equalsIgnoreCase(channelType)) {
isMock = true;
+ LOGGER.info("Using memory channel");
}
this.nativeWriterPtr = createWriterNative(
- initialParameters,
+ initParameters,
outputChannelsBytes,
msgIds,
channelSize,
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ReflectionUtils.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ReflectionUtils.java
index d3f26a06a..5852220af 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ReflectionUtils.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ReflectionUtils.java
@@ -19,6 +19,7 @@ public class ReflectionUtils {
/**
* For covariant return type, return the most specific method.
+ *
* @return all methods named by {@code methodName},
*/
public static List findMethods(Class> cls, String methodName) {
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java
index 75d587c5a..2433d18e9 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java
@@ -1,5 +1,6 @@
package io.ray.streaming.runtime.worker;
+import io.ray.api.Ray;
import io.ray.streaming.runtime.core.graph.ExecutionGraph;
import io.ray.streaming.runtime.core.graph.ExecutionNode;
import io.ray.streaming.runtime.core.graph.ExecutionNode.NodeType;
@@ -14,11 +15,8 @@ import io.ray.streaming.runtime.worker.context.WorkerContext;
import io.ray.streaming.runtime.worker.tasks.OneInputStreamTask;
import io.ray.streaming.runtime.worker.tasks.SourceStreamTask;
import io.ray.streaming.runtime.worker.tasks.StreamTask;
-import io.ray.streaming.util.Config;
-
import java.io.Serializable;
import java.util.Map;
-
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -27,6 +25,8 @@ import org.slf4j.LoggerFactory;
*/
public class JobWorker implements Serializable {
private static final Logger LOGGER = LoggerFactory.getLogger(JobWorker.class);
+ // Special flag returned to callers to indicate that this actor is not ready yet.
+ private static final byte[] NOT_READY_FLAG = new byte[4];
static {
EnvUtil.loadNativeLibraries();
@@ -53,12 +53,11 @@ public class JobWorker implements Serializable {
this.nodeType = executionNode.getNodeType();
this.streamProcessor = ProcessBuilder
- .buildProcessor(executionNode.getStreamOperator());
- LOGGER.debug("Initializing StreamWorker, taskId: {}, operator: {}.", taskId, streamProcessor);
+ .buildProcessor(executionNode.getStreamOperator());
+ LOGGER.info("Initializing StreamWorker, pid {}, taskId: {}, operator: {}.",
+ EnvUtil.getJvmPid(), taskId, streamProcessor);
- String channelType = (String) this.config.getOrDefault(
- Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
- if (channelType.equals(Config.NATIVE_CHANNEL)) {
+ if (!Ray.getRuntimeContext().isSingleProcess()) {
transferHandler = new TransferHandler();
}
task = createStreamTask();
@@ -124,6 +123,9 @@ public class JobWorker implements Serializable {
* and receive result from this actor
*/
public byte[] onReaderMessageSync(byte[] buffer) {
+ if (transferHandler == null) {
+ return NOT_READY_FLAG;
+ }
return transferHandler.onReaderMessageSync(buffer);
}
@@ -139,6 +141,9 @@ public class JobWorker implements Serializable {
* and receive result from this actor
*/
public byte[] onWriterMessageSync(byte[] buffer) {
+ if (transferHandler == null) {
+ return NOT_READY_FLAG;
+ }
return transferHandler.onWriterMessageSync(buffer);
}
}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java
index a3fd7a470..8d642aeef 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java
@@ -1,7 +1,9 @@
package io.ray.streaming.runtime.worker.tasks;
-import io.ray.runtime.serializer.Serializer;
import io.ray.streaming.runtime.core.processor.Processor;
+import io.ray.streaming.runtime.serialization.CrossLangSerializer;
+import io.ray.streaming.runtime.serialization.JavaSerializer;
+import io.ray.streaming.runtime.serialization.Serializer;
import io.ray.streaming.runtime.transfer.Message;
import io.ray.streaming.runtime.worker.JobWorker;
import io.ray.streaming.util.Config;
@@ -10,11 +12,15 @@ public abstract class InputStreamTask extends StreamTask {
private volatile boolean running = true;
private volatile boolean stopped = false;
private long readTimeoutMillis;
+ private final io.ray.streaming.runtime.serialization.Serializer javaSerializer;
+ private final io.ray.streaming.runtime.serialization.Serializer crossLangSerializer;
public InputStreamTask(int taskId, Processor processor, JobWorker streamWorker) {
super(taskId, processor, streamWorker);
readTimeoutMillis = Long.parseLong((String) streamWorker.getConfig()
.getOrDefault(Config.READ_TIMEOUT_MS, Config.DEFAULT_READ_TIMEOUT_MS));
+ javaSerializer = new JavaSerializer();
+ crossLangSerializer = new CrossLangSerializer();
}
@Override
@@ -26,9 +32,15 @@ public abstract class InputStreamTask extends StreamTask {
while (running) {
Message item = reader.read(readTimeoutMillis);
if (item != null) {
- byte[] bytes = new byte[item.body().remaining()];
+ byte[] bytes = new byte[item.body().remaining() - 1];
+ byte typeId = item.body().get();
item.body().get(bytes);
- Object obj = Serializer.decode(bytes, Object.class);
+ Object obj;
+ if (typeId == Serializer.JAVA_TYPE_ID) {
+ obj = javaSerializer.deserialize(bytes);
+ } else {
+ obj = crossLangSerializer.deserialize(bytes);
+ }
processor.process(obj);
}
}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java
index d16cc029d..ca2e6aa99 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java
@@ -26,7 +26,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public abstract class StreamTask implements Runnable {
-
private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class);
protected int taskId;
@@ -53,8 +52,8 @@ public abstract class StreamTask implements Runnable {
String queueSize = worker.getConfig()
.getOrDefault(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT);
queueConf.put(Config.CHANNEL_SIZE, queueSize);
- String channelType = worker.getConfig()
- .getOrDefault(Config.CHANNEL_TYPE, Config.MEMORY_CHANNEL);
+ String channelType = Ray.getRuntimeContext().isSingleProcess() ?
+ Config.MEMORY_CHANNEL : Config.NATIVE_CHANNEL;
queueConf.put(Config.CHANNEL_TYPE, channelType);
ExecutionGraph executionGraph = worker.getExecutionGraph();
@@ -82,7 +81,7 @@ public abstract class StreamTask implements Runnable {
LOG.info("Create DataWriter succeed.");
writers.put(edge, writer);
Partition partition = edge.getPartition();
- collectors.add(new OutputCollector(channelIDs, writer, partition));
+ collectors.add(new OutputCollector(writer, channelIDs, outputActors.values(), partition));
}
}
@@ -106,8 +105,8 @@ public abstract class StreamTask implements Runnable {
reader = new DataReader(channelIDs, inputActors, queueConf);
}
- RuntimeContext runtimeContext = new RayRuntimeContext(worker.getExecutionTask(),
- worker.getConfig(), executionNode.getParallelism());
+ RuntimeContext runtimeContext = new RayRuntimeContext(
+ worker.getExecutionTask(), worker.getConfig(), executionNode.getParallelism());
processor.open(collectors, runtimeContext);
diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/BaseUnitTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/BaseUnitTest.java
index e757f14e1..593851a86 100644
--- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/BaseUnitTest.java
+++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/BaseUnitTest.java
@@ -24,11 +24,13 @@ public abstract class BaseUnitTest {
@BeforeMethod
public void testBegin(Method method) {
- LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: " + method.getName() + " began >>>>>>>>>>>>>>>>>>>>");
+ LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: {}.{} began >>>>>>>>>>>>>>>>>>>>",
+ method.getDeclaringClass(), method.getName());
}
@AfterMethod
public void testEnd(Method method) {
- LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: " + method.getName() + " end >>>>>>>>>>>>>>>>>>");
+ LOG.info(">>>>>>>>>>>>>>>>>>>> Test case: {}.{} end >>>>>>>>>>>>>>>>>>>>",
+ method.getDeclaringClass(), method.getName());
}
}
diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java
index 920bc1f74..882fc5fb4 100644
--- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java
+++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java
@@ -80,7 +80,7 @@ public class ExecutionGraphTest extends BaseUnitTest {
public static JobGraph buildJobGraph() {
StreamingContext streamingContext = StreamingContext.buildContext();
- DataStream dataStream = DataStreamSource.buildSource(streamingContext,
+ DataStream dataStream = DataStreamSource.fromCollection(streamingContext,
Lists.newArrayList("a", "b", "c"));
StreamSink streamSink = dataStream.sink(x -> LOG.info(x));
diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/HybridStreamTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/HybridStreamTest.java
new file mode 100644
index 000000000..025f67e21
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/HybridStreamTest.java
@@ -0,0 +1,56 @@
+package io.ray.streaming.runtime.demo;
+
+import io.ray.api.Ray;
+import io.ray.streaming.api.context.StreamingContext;
+import io.ray.streaming.api.function.impl.FilterFunction;
+import io.ray.streaming.api.function.impl.MapFunction;
+import io.ray.streaming.api.stream.DataStreamSource;
+import io.ray.streaming.runtime.BaseUnitTest;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.concurrent.TimeUnit;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testng.annotations.Test;
+
+public class HybridStreamTest extends BaseUnitTest implements Serializable {
+ private static final Logger LOG = LoggerFactory.getLogger(HybridStreamTest.class);
+
+ public static class Mapper1 implements MapFunction {
+
+ @Override
+ public Object map(Object value) {
+ LOG.info("HybridStreamTest Mapper1 {}", value);
+ return value.toString();
+ }
+ }
+
+ public static class Filter1 implements FilterFunction {
+
+ @Override
+ public boolean filter(Object value) throws Exception {
+ LOG.info("HybridStreamTest Filter1 {}", value);
+ return !value.toString().contains("b");
+ }
+ }
+
+ @Test
+ public void testHybridDataStream() throws InterruptedException {
+ Ray.shutdown();
+ StreamingContext context = StreamingContext.buildContext();
+ DataStreamSource streamSource =
+ DataStreamSource.fromCollection(context, Arrays.asList("a", "b", "c"));
+ streamSource
+ .map(x -> x + x)
+ .asPythonStream()
+ .map("ray.streaming.tests.test_hybrid_stream", "map_func1")
+ .filter("ray.streaming.tests.test_hybrid_stream", "filter_func1")
+ .asJavaStream()
+ .sink(x -> System.out.println("HybridStreamTest: " + x));
+ context.execute("HybridStreamTestJob");
+ TimeUnit.SECONDS.sleep(3);
+ context.stop();
+ LOG.info("HybridStreamTest succeed");
+ }
+
+}
diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/WordCountTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/WordCountTest.java
index 389c1bc1a..5669ad12f 100644
--- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/WordCountTest.java
+++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/WordCountTest.java
@@ -1,6 +1,7 @@
package io.ray.streaming.runtime.demo;
import com.google.common.collect.ImmutableMap;
+import io.ray.api.Ray;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.FlatMapFunction;
import io.ray.streaming.api.function.impl.ReduceFunction;
@@ -29,6 +30,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable {
@Test
public void testWordCount() {
+ Ray.shutdown();
StreamingContext streamingContext = StreamingContext.buildContext();
Map config = new HashMap<>();
config.put(Config.STREAMING_BATCH_MAX_COUNT, "1");
@@ -36,7 +38,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable {
streamingContext.withConfig(config);
List text = new ArrayList<>();
text.add("hello world eagle eagle eagle");
- DataStreamSource streamSource = DataStreamSource.buildSource(streamingContext, text);
+ DataStreamSource streamSource = DataStreamSource.fromCollection(streamingContext, text);
streamSource
.flatMap((FlatMapFunction) (value, collector) -> {
String[] records = value.split(" ");
@@ -62,6 +64,7 @@ public class WordCountTest extends BaseUnitTest implements Serializable {
}
}
Assert.assertEquals(wordCount, ImmutableMap.of("eagle", 3, "hello", 1));
+ streamingContext.stop();
}
private static class WordAndCount implements Serializable {
diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/PythonGatewayTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/PythonGatewayTest.java
index 51440dba6..5922cc578 100644
--- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/PythonGatewayTest.java
+++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/PythonGatewayTest.java
@@ -3,6 +3,7 @@ package io.ray.streaming.runtime.python;
import io.ray.streaming.api.stream.StreamSink;
import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.jobgraph.JobGraphBuilder;
+import io.ray.streaming.runtime.serialization.MsgPackSerializer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/schedule/TaskAssignerImplTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/schedule/TaskAssignerImplTest.java
index 7c2e7e7ff..2e8978c7a 100644
--- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/schedule/TaskAssignerImplTest.java
+++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/schedule/TaskAssignerImplTest.java
@@ -57,7 +57,7 @@ public class TaskAssignerImplTest extends BaseUnitTest {
public JobGraph buildDataSyncPlan() {
StreamingContext streamingContext = StreamingContext.buildContext();
- DataStream dataStream = DataStreamSource.buildSource(streamingContext,
+ DataStream dataStream = DataStreamSource.fromCollection(streamingContext,
Lists.newArrayList("a", "b", "c"));
DataStreamSink streamSink = dataStream.sink(LOGGER::info);
JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink));
diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/CrossLangSerializerTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/CrossLangSerializerTest.java
new file mode 100644
index 000000000..be92792b6
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/CrossLangSerializerTest.java
@@ -0,0 +1,26 @@
+package io.ray.streaming.runtime.serialization;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+import org.apache.commons.lang3.builder.EqualsBuilder;
+import io.ray.streaming.message.KeyRecord;
+import io.ray.streaming.message.Record;
+import org.testng.annotations.Test;
+
+public class CrossLangSerializerTest {
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testSerialize() {
+ CrossLangSerializer serializer = new CrossLangSerializer();
+ Record record = new Record("value");
+ record.setStream("stream1");
+ assertTrue(EqualsBuilder.reflectionEquals(record,
+ serializer.deserialize(serializer.serialize(record))));
+ KeyRecord keyRecord = new KeyRecord("key", "value");
+ keyRecord.setStream("stream2");
+ assertEquals(keyRecord,
+ serializer.deserialize(serializer.serialize(keyRecord)));
+ }
+}
\ No newline at end of file
diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/MsgPackSerializerTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/MsgPackSerializerTest.java
similarity index 59%
rename from streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/MsgPackSerializerTest.java
rename to streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/MsgPackSerializerTest.java
index b2213538b..44568df8d 100644
--- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/python/MsgPackSerializerTest.java
+++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/serialization/MsgPackSerializerTest.java
@@ -1,4 +1,7 @@
-package io.ray.streaming.runtime.python;
+package io.ray.streaming.runtime.serialization;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
import java.util.ArrayList;
import java.util.Arrays;
@@ -6,25 +9,37 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.testng.annotations.Test;
-import static org.testng.Assert.assertEquals;
-import static org.testng.Assert.assertTrue;
@SuppressWarnings("unchecked")
public class MsgPackSerializerTest {
+ @Test
+ public void testSerializeByte() {
+ MsgPackSerializer serializer = new MsgPackSerializer();
+
+ assertEquals(serializer.deserialize(
+ serializer.serialize((byte)1)), (byte)1);
+ }
+
@Test
public void testSerialize() {
MsgPackSerializer serializer = new MsgPackSerializer();
+ assertEquals(serializer.deserialize
+ (serializer.serialize(Short.MAX_VALUE)), Short.MAX_VALUE);
+ assertEquals(serializer.deserialize(
+ serializer.serialize(Integer.MAX_VALUE)), Integer.MAX_VALUE);
+ assertEquals(serializer.deserialize(
+ serializer.serialize(Long.MAX_VALUE)), Long.MAX_VALUE);
+
Map map = new HashMap();
List list = new ArrayList<>();
list.add(null);
list.add(true);
- list.add(1);
list.add(1.0d);
list.add("str");
map.put("k1", "value1");
- map.put("k2", 2);
+ map.put("k2", new HashMap<>());
map.put("k3", list);
byte[] bytes = serializer.serialize(map);
Object o = serializer.deserialize(bytes);
diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java
index cfa34dd04..c48293cea 100644
--- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java
+++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java
@@ -5,6 +5,7 @@ import io.ray.api.Ray;
import io.ray.api.RayActor;
import io.ray.api.options.ActorCreationOptions;
import io.ray.api.options.ActorCreationOptions.Builder;
+import io.ray.runtime.config.RayConfig;
import io.ray.streaming.api.context.StreamingContext;
import io.ray.streaming.api.function.impl.FlatMapFunction;
import io.ray.streaming.api.function.impl.ReduceFunction;
@@ -67,7 +68,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
System.setProperty("ray.run-mode", "CLUSTER");
System.setProperty("ray.redirect-output", "true");
- // ray init
+ RayConfig.reset();
Ray.init();
}
@@ -142,6 +143,14 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
@Test(timeOut = 60000)
public void testWordCount() {
+ Ray.shutdown();
+ System.setProperty("ray.resources", "CPU:4,RES-A:4");
+ System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
+
+ System.setProperty("ray.run-mode", "CLUSTER");
+ System.setProperty("ray.redirect-output", "true");
+ // ray init
+ Ray.init();
LOGGER.info("testWordCount");
LOGGER.info("StreamingQueueTest.testWordCount run-mode: {}",
System.getProperty("ray.run-mode"));
@@ -157,7 +166,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
streamingContext.withConfig(config);
List text = new ArrayList<>();
text.add("hello world eagle eagle eagle");
- DataStreamSource streamSource = DataStreamSource.buildSource(streamingContext, text);
+ DataStreamSource streamSource = DataStreamSource.fromCollection(streamingContext, text);
streamSource
.flatMap((FlatMapFunction) (value, collector) -> {
String[] records = value.split(" ");
@@ -176,7 +185,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
serializeResultToFile(resultFile, wordCount);
});
- streamingContext.execute("testWordCount");
+ streamingContext.execute("testSQWordCount");
Map checkWordCount =
(Map) deserializeResultFromFile(resultFile);
diff --git a/streaming/java/test.sh b/streaming/java/test.sh
index e3225452c..ecf9770a8 100755
--- a/streaming/java/test.sh
+++ b/streaming/java/test.sh
@@ -23,8 +23,11 @@ bazel test //streaming/java:all --test_tag_filters="checkstyle" --build_tests_on
echo "Running streaming tests."
java -cp "$ROOT_DIR"/../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar\
- org.testng.TestNG -d /tmp/ray_streaming_java_test_output "$ROOT_DIR"/testng.xml
+ org.testng.TestNG -d /tmp/ray_streaming_java_test_output "$ROOT_DIR"/testng.xml ||
exit_code=$?
+if [ -z ${exit_code+x} ]; then
+ exit_code=0
+fi
echo "Streaming TestNG results"
if [ -f "/tmp/ray_streaming_java_test_output/testng-results.xml" ] ; then
cat /tmp/ray_streaming_java_test_output/testng-results.xml
diff --git a/streaming/python/collector.py b/streaming/python/collector.py
index cc803eaf4..12b6c096b 100644
--- a/streaming/python/collector.py
+++ b/streaming/python/collector.py
@@ -1,10 +1,13 @@
import logging
-import pickle
import typing
from abc import ABC, abstractmethod
+from ray import Language
+from ray.actor import ActorHandle
+from ray.streaming import function
from ray.streaming import message
from ray.streaming import partition
+from ray.streaming.runtime import serialization
from ray.streaming.runtime.transfer import ChannelID, DataWriter
logger = logging.getLogger(__name__)
@@ -31,19 +34,46 @@ class CollectionCollector(Collector):
class OutputCollector(Collector):
- def __init__(self, channel_ids: typing.List[str], writer: DataWriter,
+ def __init__(self, writer: DataWriter, channel_ids: typing.List[str],
+ target_actors: typing.List[ActorHandle],
partition_func: partition.Partition):
- self._channel_ids = [ChannelID(id_str) for id_str in channel_ids]
self._writer = writer
+ self._channel_ids = [ChannelID(id_str) for id_str in channel_ids]
+ self._target_languages = []
+ for actor in target_actors:
+ if actor._ray_actor_language == Language.PYTHON:
+ self._target_languages.append(function.Language.PYTHON)
+ elif actor._ray_actor_language == Language.JAVA:
+ self._target_languages.append(function.Language.JAVA)
+ else:
+ raise Exception("Unsupported language {}"
+ .format(actor._ray_actor_language))
self._partition_func = partition_func
+ self.python_serializer = serialization.PythonSerializer()
+ self.cross_lang_serializer = serialization.CrossLangSerializer()
logger.info(
"Create OutputCollector, channel_ids {}, partition_func {}".format(
channel_ids, partition_func))
def collect(self, record):
- partitions = self._partition_func.partition(record,
- len(self._channel_ids))
- serialized_message = pickle.dumps(record)
+ partitions = self._partition_func \
+ .partition(record, len(self._channel_ids))
+ python_buffer = None
+ cross_lang_buffer = None
for partition_index in partitions:
- self._writer.write(self._channel_ids[partition_index],
- serialized_message)
+ if self._target_languages[partition_index] == \
+ function.Language.PYTHON:
+ # avoid repeated serialization
+ if python_buffer is None:
+ python_buffer = self.python_serializer.serialize(record)
+ self._writer.write(
+ self._channel_ids[partition_index],
+ serialization._PYTHON_TYPE_ID + python_buffer)
+ else:
+ # avoid repeated serialization
+ if cross_lang_buffer is None:
+ cross_lang_buffer = self.cross_lang_serializer.serialize(
+ record)
+ self._writer.write(
+ self._channel_ids[partition_index],
+ serialization._CROSS_LANG_TYPE_ID + cross_lang_buffer)
diff --git a/streaming/python/datastream.py b/streaming/python/datastream.py
index 39a067a6a..26297da11 100644
--- a/streaming/python/datastream.py
+++ b/streaming/python/datastream.py
@@ -1,4 +1,4 @@
-from abc import ABC
+from abc import ABC, abstractmethod
from ray.streaming import function
from ray.streaming import partition
@@ -19,7 +19,6 @@ class Stream(ABC):
self.streaming_context = input_stream.streaming_context
else:
self.streaming_context = streaming_context
- self.parallelism = 1
def get_streaming_context(self):
return self.streaming_context
@@ -29,7 +28,8 @@ class Stream(ABC):
Returns:
the parallelism of this transformation
"""
- return self.parallelism
+ return self._gateway_client(). \
+ call_method(self._j_stream, "getParallelism")
def set_parallelism(self, parallelism: int):
"""Sets the parallelism of this transformation
@@ -40,7 +40,6 @@ class Stream(ABC):
Returns:
self
"""
- self.parallelism = parallelism
self._gateway_client(). \
call_method(self._j_stream, "setParallelism", parallelism)
return self
@@ -60,6 +59,10 @@ class Stream(ABC):
return self._gateway_client(). \
call_method(self._j_stream, "getId")
+ @abstractmethod
+ def get_language(self):
+ pass
+
def _gateway_client(self):
return self.get_streaming_context()._gateway_client
@@ -75,6 +78,9 @@ class DataStream(Stream):
super().__init__(
input_stream, j_stream, streaming_context=streaming_context)
+ def get_language(self):
+ return function.Language.PYTHON
+
def map(self, func):
"""
Applies a Map transformation on a :class:`DataStream`.
@@ -158,6 +164,7 @@ class DataStream(Stream):
Returns:
A KeyDataStream
"""
+ self._check_partition_call()
if not isinstance(func, function.KeyFunction):
func = function.SimpleKeyFunction(func)
j_func = self._gateway_client().create_py_func(
@@ -175,6 +182,7 @@ class DataStream(Stream):
Returns:
The DataStream with broadcast partitioning set.
"""
+ self._check_partition_call()
self._gateway_client().call_method(self._j_stream, "broadcast")
return self
@@ -191,6 +199,7 @@ class DataStream(Stream):
Returns:
The DataStream with specified partitioning set.
"""
+ self._check_partition_call()
if not isinstance(partition_func, partition.Partition):
partition_func = partition.SimplePartition(partition_func)
j_partition = self._gateway_client().create_py_func(
@@ -199,6 +208,16 @@ class DataStream(Stream):
call_method(self._j_stream, "partitionBy", j_partition)
return self
+ def _check_partition_call(self):
+ """
+ Raise an exception if the parent stream is a Java stream, because
+ partition-related methods can't be called on such a Python stream.
+ """
+ if self.input_stream is not None and \
+ self.input_stream.get_language() == function.Language.JAVA:
+ raise Exception("Partition related methods can't be called on a "
+ "python stream if parent stream is a java stream.")
+
def sink(self, func):
"""
Create a StreamSink with the given sink.
@@ -217,8 +236,97 @@ class DataStream(Stream):
call_method(self._j_stream, "sink", j_func)
return StreamSink(self, j_stream, func)
+ def as_java_stream(self):
+ """
+ Convert this stream to a java JavaDataStream.
+ The converted stream and this stream are the same logical stream and
+ share the same stream id. Changes to the converted stream are reflected
+ in this stream and vice versa.
+ """
+ j_stream = self._gateway_client(). \
+ call_method(self._j_stream, "asJavaStream")
+ return JavaDataStream(self, j_stream)
-class KeyDataStream(Stream):
+
+class JavaDataStream(Stream):
+ """
+ Represents a stream of data which applies a transformation executed by
+ java. It's also a wrapper of java
+ `org.ray.streaming.api.stream.DataStream`
+ """
+
+ def __init__(self, input_stream, j_stream, streaming_context=None):
+ super().__init__(
+ input_stream, j_stream, streaming_context=streaming_context)
+
+ def get_language(self):
+ return function.Language.JAVA
+
+ def map(self, java_func_class):
+ """See org.ray.streaming.api.stream.DataStream.map"""
+ return JavaDataStream(self, self._unary_call("map", java_func_class))
+
+ def flat_map(self, java_func_class):
+ """See org.ray.streaming.api.stream.DataStream.flatMap"""
+ return JavaDataStream(self, self._unary_call("flatMap",
+ java_func_class))
+
+ def filter(self, java_func_class):
+ """See org.ray.streaming.api.stream.DataStream.filter"""
+ return JavaDataStream(self, self._unary_call("filter",
+ java_func_class))
+
+ def key_by(self, java_func_class):
+ """See org.ray.streaming.api.stream.DataStream.keyBy"""
+ self._check_partition_call()
+ return JavaKeyDataStream(self,
+ self._unary_call("keyBy", java_func_class))
+
+ def broadcast(self, java_func_class):
+ """See org.ray.streaming.api.stream.DataStream.broadcast"""
+ self._check_partition_call()
+ return JavaDataStream(self,
+ self._unary_call("broadcast", java_func_class))
+
+ def partition_by(self, java_func_class):
+ """See org.ray.streaming.api.stream.DataStream.partitionBy"""
+ self._check_partition_call()
+ return JavaDataStream(self,
+ self._unary_call("partitionBy", java_func_class))
+
+ def sink(self, java_func_class):
+ """See org.ray.streaming.api.stream.DataStream.sink"""
+ return JavaStreamSink(self, self._unary_call("sink", java_func_class))
+
+ def as_python_stream(self):
+ """
+ Convert this stream to a python DataStream.
+ The converted stream and this stream are the same logical stream and
+ share the same stream id. Changes to the converted stream are reflected
+ in this stream and vice versa.
+ """
+ j_stream = self._gateway_client(). \
+ call_method(self._j_stream, "asPythonStream")
+ return DataStream(self, j_stream)
+
+ def _check_partition_call(self):
+ """
+ If the parent stream is a python stream, partition related methods
+ can't be called on this java stream.
+ """
+ if self.input_stream is not None and \
+ self.input_stream.get_language() == function.Language.PYTHON:
+ raise Exception("Partition related methods can't be called on a"
+ "java stream if parent stream is a python stream.")
+
+ def _unary_call(self, func_name, java_func_class):
+ j_func = self._gateway_client().new_instance(java_func_class)
+ j_stream = self._gateway_client(). \
+ call_method(self._j_stream, func_name, j_func)
+ return j_stream
+
+
+class KeyDataStream(DataStream):
"""Represents a DataStream returned by a key-by operation.
Wrapper of java io.ray.streaming.python.stream.PythonKeyDataStream
"""
@@ -251,6 +359,43 @@ class KeyDataStream(Stream):
call_method(self._j_stream, "reduce", j_func)
return DataStream(self, j_stream)
+ def as_java_stream(self):
+ """
+ Convert this stream to a java KeyDataStream.
+ The converted stream and this stream are the same logical stream and
+ share the same stream id. Changes to the converted stream are reflected
+ in this stream and vice versa.
+ """
+ j_stream = self._gateway_client(). \
+ call_method(self._j_stream, "asJavaStream")
+ return JavaKeyDataStream(self, j_stream)
+
+
+class JavaKeyDataStream(JavaDataStream):
+ """
+ Represents a DataStream returned by a key-by operation in java.
+ Wrapper of org.ray.streaming.api.stream.KeyDataStream
+ """
+
+ def __init__(self, input_stream, j_stream):
+ super().__init__(input_stream, j_stream)
+
+ def reduce(self, java_func_class):
+ """See org.ray.streaming.api.stream.KeyDataStream.reduce"""
+ return JavaDataStream(self,
+ super()._unary_call("reduce", java_func_class))
+
+ def as_python_stream(self):
+ """
+ Convert this stream to a python KeyDataStream.
+ The converted stream and this stream are the same logical stream and
+ share the same stream id. Changes to the converted stream are reflected
+ in this stream and vice versa.
+ """
+ j_stream = self._gateway_client(). \
+ call_method(self._j_stream, "asPythonStream")
+ return KeyDataStream(self, j_stream)
+
class StreamSource(DataStream):
"""Represents a source of the DataStream.
@@ -261,9 +406,12 @@ class StreamSource(DataStream):
super().__init__(None, j_stream, streaming_context=streaming_context)
self.source_func = source_func
+ def get_language(self):
+ return function.Language.PYTHON
+
@staticmethod
def build_source(streaming_context, func):
- """Build a StreamSource source from a collection.
+ """Build a StreamSource from a source function.
Args:
streaming_context: Stream context
func: A instance of `SourceFunction`
@@ -275,6 +423,34 @@ class StreamSource(DataStream):
return StreamSource(j_stream, streaming_context, func)
+class JavaStreamSource(JavaDataStream):
+ """Represents a source of the java DataStream.
+ Wrapper of java org.ray.streaming.api.stream.DataStreamSource
+ """
+
+ def __init__(self, j_stream, streaming_context):
+ super().__init__(None, j_stream, streaming_context=streaming_context)
+
+ def get_language(self):
+ return function.Language.JAVA
+
+ @staticmethod
+ def build_source(streaming_context, java_source_func_class):
+ """Build a java StreamSource from a java source function.
+ Args:
+ streaming_context: Stream context
+ java_source_func_class: qualified class name of java SourceFunction
+ Returns:
+ A java StreamSource
+ """
+ j_func = streaming_context._gateway_client() \
+ .new_instance(java_source_func_class)
+ j_stream = streaming_context._gateway_client() \
+ .call_function("org.ray.streaming.api.stream.DataStreamSource"
+ "fromSource", streaming_context._j_ctx, j_func)
+ return JavaStreamSource(j_stream, streaming_context)
+
+
class StreamSink(Stream):
"""Represents a sink of the DataStream.
Wrapper of java io.ray.streaming.python.stream.PythonStreamSink
@@ -282,3 +458,18 @@ class StreamSink(Stream):
def __init__(self, input_stream, j_stream, func):
super().__init__(input_stream, j_stream)
+
+ def get_language(self):
+ return function.Language.PYTHON
+
+
+class JavaStreamSink(Stream):
+ """Represents a sink of the java DataStream.
+ Wrapper of java org.ray.streaming.api.stream.StreamSink
+ """
+
+ def __init__(self, input_stream, j_stream):
+ super().__init__(input_stream, j_stream)
+
+ def get_language(self):
+ return function.Language.JAVA
diff --git a/streaming/python/function.py b/streaming/python/function.py
index 9a9a22a19..8d38ae6bc 100644
--- a/streaming/python/function.py
+++ b/streaming/python/function.py
@@ -1,13 +1,19 @@
+import enum
import importlib
import inspect
import sys
-from abc import ABC, abstractmethod
import typing
+from abc import ABC, abstractmethod
from ray import cloudpickle
from ray.streaming.runtime import gateway_client
+class Language(enum.Enum):
+ JAVA = 0
+ PYTHON = 1
+
+
class Function(ABC):
"""The base interface for all user-defined functions."""
@@ -60,6 +66,7 @@ class MapFunction(Function):
for each input element.
"""
+ @abstractmethod
def map(self, value):
pass
@@ -70,6 +77,7 @@ class FlatMapFunction(Function):
transform them into zero, one, or more elements.
"""
+ @abstractmethod
def flat_map(self, value, collector):
"""Takes an element from the input data set and transforms it into zero,
one, or more elements.
@@ -87,6 +95,7 @@ class FilterFunction(Function):
The predicate decides whether to keep the element, or to discard it.
"""
+ @abstractmethod
def filter(self, value):
"""The filter function that evaluates the predicate.
@@ -106,6 +115,7 @@ class KeyFunction(Function):
deterministic key for that object.
"""
+ @abstractmethod
def key_by(self, value):
"""User-defined function that deterministically extracts the key from
an object.
@@ -126,6 +136,7 @@ class ReduceFunction(Function):
them into one.
"""
+ @abstractmethod
def reduce(self, old_value, new_value):
"""
The core method of ReduceFunction, combining two values into one value
@@ -145,6 +156,7 @@ class ReduceFunction(Function):
class SinkFunction(Function):
"""Interface for implementing user defined sink functionality."""
+ @abstractmethod
def sink(self, value):
"""Writes the given value to the sink. This function is called for
every record."""
@@ -283,7 +295,8 @@ def load_function(descriptor_func_bytes: bytes):
Returns:
a streaming function
"""
- function_bytes, module_name, class_name, function_name, function_interface\
+ assert len(descriptor_func_bytes) > 0
+ function_bytes, module_name, function_name, function_interface\
= gateway_client.deserialize(descriptor_func_bytes)
if function_bytes:
return deserialize(function_bytes)
@@ -292,16 +305,18 @@ def load_function(descriptor_func_bytes: bytes):
assert function_interface
function_interface = getattr(sys.modules[__name__], function_interface)
mod = importlib.import_module(module_name)
- if class_name:
- assert function_name is None
- cls = getattr(mod, class_name)
- assert issubclass(cls, function_interface)
- return cls()
- else:
- assert function_name
- func = getattr(mod, function_name)
+ assert function_name
+ func = getattr(mod, function_name)
+ # If func is a python function, user function is a simple python
+ # function, which will be wrapped as a SimpleXXXFunction.
+ # If func is a python class, user function is a sub class
+ # of XXXFunction.
+ if inspect.isfunction(func):
simple_func_class = _get_simple_function_class(function_interface)
return simple_func_class(func)
+ else:
+ assert issubclass(func, function_interface)
+ return func()
def _get_simple_function_class(function_interface):
diff --git a/streaming/python/message.py b/streaming/python/message.py
index fab29d4bf..94d928e1d 100644
--- a/streaming/python/message.py
+++ b/streaming/python/message.py
@@ -8,6 +8,14 @@ class Record:
def __repr__(self):
return "Record(%s)".format(self.value)
+ def __eq__(self, other):
+ if type(self) is type(other):
+ return (self.stream, self.value) == (other.stream, other.value)
+ return False
+
+ def __hash__(self):
+ return hash((self.stream, self.value))
+
class KeyRecord(Record):
"""Data record in a keyed data stream"""
@@ -15,3 +23,12 @@ class KeyRecord(Record):
def __init__(self, key, value):
super().__init__(value)
self.key = key
+
+ def __eq__(self, other):
+ if type(self) is type(other):
+ return (self.stream, self.key, self.value) ==\
+ (other.stream, other.key, other.value)
+ return False
+
+ def __hash__(self):
+ return hash((self.stream, self.key, self.value))
diff --git a/streaming/python/partition.py b/streaming/python/partition.py
index 722fb7933..198fbe3d7 100644
--- a/streaming/python/partition.py
+++ b/streaming/python/partition.py
@@ -1,4 +1,5 @@
import importlib
+import inspect
from abc import ABC, abstractmethod
from ray import cloudpickle
@@ -96,22 +97,22 @@ def load_partition(descriptor_partition_bytes: bytes):
Returns:
partition function
"""
- partition_bytes, module_name, class_name, function_name =\
+ assert len(descriptor_partition_bytes) > 0
+ partition_bytes, module_name, function_name =\
gateway_client.deserialize(descriptor_partition_bytes)
if partition_bytes:
return deserialize(partition_bytes)
else:
assert module_name
mod = importlib.import_module(module_name)
- # If class_name is not None, user partition is a sub class
- # of Partition.
- # If function_name is not None, user partition is a simple python
+ assert function_name
+ func = getattr(mod, function_name)
+ # If func is a python function, user partition is a simple python
# function, which will be wrapped as a SimplePartition.
- if class_name:
- assert function_name is None
- cls = getattr(mod, class_name)
- return cls()
- else:
- assert function_name
- func = getattr(mod, function_name)
+ # If func is a python class, user partition is a sub class
+ # of Partition.
+ if inspect.isfunction(func):
return SimplePartition(func)
+ else:
+ assert issubclass(func, Partition)
+ return func()
diff --git a/streaming/python/runtime/gateway_client.py b/streaming/python/runtime/gateway_client.py
index 5477d9230..8fa4fac61 100644
--- a/streaming/python/runtime/gateway_client.py
+++ b/streaming/python/runtime/gateway_client.py
@@ -55,6 +55,11 @@ class GatewayClient:
call = self._python_gateway_actor.callMethod.remote(java_params)
return deserialize(ray.get(call))
+ def new_instance(self, java_class_name):
+ call = self._python_gateway_actor.newInstance.remote(
+ serialize(java_class_name))
+ return deserialize(ray.get(call))
+
def serialize(obj) -> bytes:
"""Serialize a python object which can be deserialized by `PythonGateway`
diff --git a/streaming/python/runtime/graph.py b/streaming/python/runtime/graph.py
index 78396dce5..645827601 100644
--- a/streaming/python/runtime/graph.py
+++ b/streaming/python/runtime/graph.py
@@ -53,7 +53,9 @@ class ExecutionEdge:
self.src_node_id = edge_pb.src_node_id
self.target_node_id = edge_pb.target_node_id
partition_bytes = edge_pb.partition
- if language == Language.PYTHON:
+ # Sink node doesn't have partition function,
+ # so we only deserialize partition_bytes when it's not None or empty
+ if language == Language.PYTHON and partition_bytes:
self.partition = partition.load_partition(partition_bytes)
diff --git a/streaming/python/runtime/serialization.py b/streaming/python/runtime/serialization.py
new file mode 100644
index 000000000..a01bf4e2c
--- /dev/null
+++ b/streaming/python/runtime/serialization.py
@@ -0,0 +1,57 @@
+from abc import ABC, abstractmethod
+import pickle
+import msgpack
+from ray.streaming import message
+
+_RECORD_TYPE_ID = 0
+_KEY_RECORD_TYPE_ID = 1
+_CROSS_LANG_TYPE_ID = b"0"
+_JAVA_TYPE_ID = b"1"
+_PYTHON_TYPE_ID = b"2"
+
+
+class Serializer(ABC):
+ @abstractmethod
+ def serialize(self, obj):
+ pass
+
+ @abstractmethod
+ def deserialize(self, serialized_bytes):
+ pass
+
+
+class PythonSerializer(Serializer):
+ def serialize(self, obj):
+ return pickle.dumps(obj)
+
+ def deserialize(self, serialized_bytes):
+ return pickle.loads(serialized_bytes)
+
+
+class CrossLangSerializer(Serializer):
+ """Serialize stream element between java/python"""
+
+ def serialize(self, obj):
+ if type(obj) is message.Record:
+ fields = [_RECORD_TYPE_ID, obj.stream, obj.value]
+ elif type(obj) is message.KeyRecord:
+ fields = [_KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value]
+ else:
+ raise Exception("Unsupported value {}".format(obj))
+ return msgpack.packb(fields, use_bin_type=True)
+
+ def deserialize(self, data):
+ fields = msgpack.unpackb(data, raw=False)
+ if fields[0] == _RECORD_TYPE_ID:
+ stream, value = fields[1:]
+ record = message.Record(value)
+ record.stream = stream
+ return record
+ elif fields[0] == _KEY_RECORD_TYPE_ID:
+ stream, key, value = fields[1:]
+ key_record = message.KeyRecord(key, value)
+ key_record.stream = stream
+ return key_record
+ else:
+ raise Exception("Unsupported type id {}, type {}".format(
+ fields[0], type(fields[0])))
diff --git a/streaming/python/runtime/task.py b/streaming/python/runtime/task.py
index ee0aeb561..c207c4727 100644
--- a/streaming/python/runtime/task.py
+++ b/streaming/python/runtime/task.py
@@ -1,11 +1,13 @@
import logging
-import pickle
import threading
from abc import ABC, abstractmethod
from ray.streaming.collector import OutputCollector
from ray.streaming.config import Config
from ray.streaming.context import RuntimeContextImpl
+from ray.streaming.runtime import serialization
+from ray.streaming.runtime.serialization import \
+ PythonSerializer, CrossLangSerializer
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader
logger = logging.getLogger(__name__)
@@ -38,36 +40,40 @@ class StreamTask(ABC):
# writers
collectors = []
for edge in execution_node.output_edges:
- output_actor_ids = {}
+ output_actors_map = {}
task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
edge.target_node_id)
for target_task_id, target_actor in task_id2_worker.items():
channel_name = ChannelID.gen_id(self.task_id, target_task_id,
execution_graph.build_time())
- output_actor_ids[channel_name] = target_actor
- if len(output_actor_ids) > 0:
- channel_ids = list(output_actor_ids.keys())
- to_actor_ids = list(output_actor_ids.values())
- writer = DataWriter(channel_ids, to_actor_ids, channel_conf)
- logger.info("Create DataWriter succeed.")
+ output_actors_map[channel_name] = target_actor
+ if len(output_actors_map) > 0:
+ channel_ids = list(output_actors_map.keys())
+ target_actors = list(output_actors_map.values())
+ logger.info(
+ "Create DataWriter channel_ids {}, target_actors {}."
+ .format(channel_ids, target_actors))
+ writer = DataWriter(channel_ids, target_actors, channel_conf)
self.writers[edge] = writer
collectors.append(
- OutputCollector(channel_ids, writer, edge.partition))
+ OutputCollector(writer, channel_ids, target_actors,
+ edge.partition))
# readers
- input_actor_ids = {}
+ input_actor_map = {}
for edge in execution_node.input_edges:
task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
edge.src_node_id)
for src_task_id, src_actor in task_id2_worker.items():
channel_name = ChannelID.gen_id(src_task_id, self.task_id,
execution_graph.build_time())
- input_actor_ids[channel_name] = src_actor
- if len(input_actor_ids) > 0:
- channel_ids = list(input_actor_ids.keys())
- from_actor_ids = list(input_actor_ids.values())
- logger.info("Create DataReader, channels {}.".format(channel_ids))
- self.reader = DataReader(channel_ids, from_actor_ids, channel_conf)
+ input_actor_map[channel_name] = src_actor
+ if len(input_actor_map) > 0:
+ channel_ids = list(input_actor_map.keys())
+ from_actors = list(input_actor_map.values())
+ logger.info("Create DataReader, channels {}, input_actors {}."
+ .format(channel_ids, from_actors))
+ self.reader = DataReader(channel_ids, from_actors, channel_conf)
def exit_handler():
# Make DataReader stop read data when MockQueue destructor
@@ -111,6 +117,8 @@ class InputStreamTask(StreamTask):
self.read_timeout_millis = \
int(worker.config.get(Config.READ_TIMEOUT_MS,
Config.DEFAULT_READ_TIMEOUT_MS))
+ self.python_serializer = PythonSerializer()
+ self.cross_lang_serializer = CrossLangSerializer()
def init(self):
pass
@@ -120,7 +128,11 @@ class InputStreamTask(StreamTask):
item = self.reader.read(self.read_timeout_millis)
if item is not None:
msg_data = item.body()
- msg = pickle.loads(msg_data)
+ type_id = msg_data[:1]
+ if (type_id == serialization._PYTHON_TYPE_ID):
+ msg = self.python_serializer.deserialize(msg_data[1:])
+ else:
+ msg = self.cross_lang_serializer.deserialize(msg_data[1:])
self.processor.process(msg)
self.stopped = True
diff --git a/streaming/python/runtime/transfer.py b/streaming/python/runtime/transfer.py
index f40ea087a..a6beb03de 100644
--- a/streaming/python/runtime/transfer.py
+++ b/streaming/python/runtime/transfer.py
@@ -147,13 +147,17 @@ class ChannelCreationParametersBuilder:
wrap initial parameters needed by a streaming queue
"""
_java_reader_async_function_descriptor = JavaFunctionDescriptor(
- "io.ray.streaming.runtime.worker", "onReaderMessage", "([B)V")
+ "io.ray.streaming.runtime.worker.JobWorker", "onReaderMessage",
+ "([B)V")
_java_reader_sync_function_descriptor = JavaFunctionDescriptor(
- "io.ray.streaming.runtime.worker", "onReaderMessageSync", "([B)[B")
+ "io.ray.streaming.runtime.worker.JobWorker", "onReaderMessageSync",
+ "([B)[B")
_java_writer_async_function_descriptor = JavaFunctionDescriptor(
- "io.ray.streaming.runtime.worker", "onWriterMessage", "([B)V")
+ "io.ray.streaming.runtime.worker.JobWorker", "onWriterMessage",
+ "([B)V")
_java_writer_sync_function_descriptor = JavaFunctionDescriptor(
- "io.ray.streaming.runtime.worker", "onWriterMessageSync", "([B)[B")
+ "io.ray.streaming.runtime.worker.JobWorker", "onWriterMessageSync",
+ "([B)[B")
_python_reader_async_function_descriptor = PythonFunctionDescriptor(
"ray.streaming.runtime.worker", "on_reader_message", "JobWorker")
_python_reader_sync_function_descriptor = PythonFunctionDescriptor(
diff --git a/streaming/python/runtime/worker.py b/streaming/python/runtime/worker.py
index 9743205ef..86e88a0f7 100644
--- a/streaming/python/runtime/worker.py
+++ b/streaming/python/runtime/worker.py
@@ -10,6 +10,9 @@ from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask
logger = logging.getLogger(__name__)
+# Special flag indicating that this actor is not ready.
+_NOT_READY_FLAG_ = b" " * 4
+
@ray.remote
class JobWorker(object):
@@ -66,23 +69,31 @@ class JobWorker(object):
type(self.stream_processor))
def on_reader_message(self, buffer: bytes):
- """used in direct call mode"""
+ """Called by upstream queue writer to send data message to downstream
+ queue reader.
+ """
self.reader_client.on_reader_message(buffer)
def on_reader_message_sync(self, buffer: bytes):
- """used in direct call mode"""
+ """Called by upstream queue writer to send control message to
+ downstream queue reader.
+ """
if self.reader_client is None:
- return b" " * 4 # special flag to indicate this actor not ready
+ return _NOT_READY_FLAG_
result = self.reader_client.on_reader_message_sync(buffer)
return result.to_pybytes()
def on_writer_message(self, buffer: bytes):
- """used in direct call mode"""
+ """Called by downstream queue reader to send notify message to
+ upstream queue writer.
+ """
self.writer_client.on_writer_message(buffer)
def on_writer_message_sync(self, buffer: bytes):
- """used in direct call mode"""
+ """Called by downstream queue reader to send control message to
+ upstream queue writer.
+ """
if self.writer_client is None:
- return b" " * 4 # special flag to indicate this actor not ready
+ return _NOT_READY_FLAG_
result = self.writer_client.on_writer_message_sync(buffer)
return result.to_pybytes()
diff --git a/streaming/python/tests/test_function.py b/streaming/python/tests/test_function.py
index 3564a1698..c9ce33067 100644
--- a/streaming/python/tests/test_function.py
+++ b/streaming/python/tests/test_function.py
@@ -14,9 +14,9 @@ class MapFunc(function.MapFunction):
def test_load_function():
- # function_bytes, module_name, class_name, function_name,
+ # function_bytes, module_name, function_name/class_name,
# function_interface
descriptor_func_bytes = gateway_client.serialize(
- [None, __name__, MapFunc.__name__, None, "MapFunction"])
+ [None, __name__, MapFunc.__name__, "MapFunction"])
func = function.load_function(descriptor_func_bytes)
assert type(func) is MapFunc
diff --git a/streaming/python/tests/test_hybrid_stream.py b/streaming/python/tests/test_hybrid_stream.py
new file mode 100644
index 000000000..a103be435
--- /dev/null
+++ b/streaming/python/tests/test_hybrid_stream.py
@@ -0,0 +1,70 @@
+import json
+import ray
+from ray.streaming import StreamingContext
+import subprocess
+import os
+
+
+def map_func1(x):
+ print("HybridStreamTest map_func1", x)
+ return str(x)
+
+
+def filter_func1(x):
+ print("HybridStreamTest filter_func1", x)
+ return "b" not in x
+
+
+def sink_func1(x):
+ print("HybridStreamTest sink_func1 value:", x)
+
+
+def test_hybrid_stream():
+ subprocess.check_call(
+ ["bazel", "build", "//streaming/java:all_streaming_tests_deploy.jar"])
+ current_dir = os.path.abspath(os.path.dirname(__file__))
+ jar_path = os.path.join(
+ current_dir,
+ "../../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar")
+ jar_path = os.path.abspath(jar_path)
+ print("jar_path", jar_path)
+ java_worker_options = json.dumps(["-classpath", jar_path])
+ print("java_worker_options", java_worker_options)
+ assert not ray.is_initialized()
+ ray.init(
+ load_code_from_local=True,
+ include_java=True,
+ java_worker_options=java_worker_options,
+ _internal_config=json.dumps({
+ "num_workers_per_process_java": 1
+ }))
+
+ sink_file = "/tmp/ray_streaming_test_hybrid_stream.txt"
+ if os.path.exists(sink_file):
+ os.remove(sink_file)
+
+ def sink_func(x):
+ print("HybridStreamTest", x)
+ with open(sink_file, "a") as f:
+ f.write(str(x))
+
+ ctx = StreamingContext.Builder().build()
+ ctx.from_values("a", "b", "c") \
+ .as_java_stream() \
+ .map("io.ray.streaming.runtime.demo.HybridStreamTest$Mapper1") \
+ .filter("io.ray.streaming.runtime.demo.HybridStreamTest$Filter1") \
+ .as_python_stream() \
+ .sink(sink_func)
+ ctx.submit("HybridStreamTest")
+ import time
+ time.sleep(3)
+ ray.shutdown()
+ with open(sink_file, "r") as f:
+ result = f.read()
+ assert "a" in result
+ assert "b" not in result
+ assert "c" in result
+
+
+if __name__ == "__main__":
+ test_hybrid_stream()
diff --git a/streaming/python/tests/test_serialization.py b/streaming/python/tests/test_serialization.py
new file mode 100644
index 000000000..67865f802
--- /dev/null
+++ b/streaming/python/tests/test_serialization.py
@@ -0,0 +1,13 @@
+from ray.streaming.runtime.serialization import CrossLangSerializer
+from ray.streaming.message import Record, KeyRecord
+
+
+def test_serialize():
+ serializer = CrossLangSerializer()
+ record = Record("value")
+ record.stream = "stream1"
+ key_record = KeyRecord("key", "value")
+ key_record.stream = "stream2"
+ assert record == serializer.deserialize(serializer.serialize(record))
+ assert key_record == serializer.\
+ deserialize(serializer.serialize(key_record))
diff --git a/streaming/python/tests/test_stream.py b/streaming/python/tests/test_stream.py
new file mode 100644
index 000000000..8eb0fbe6a
--- /dev/null
+++ b/streaming/python/tests/test_stream.py
@@ -0,0 +1,31 @@
+import ray
+from ray.streaming import StreamingContext
+
+
+def test_data_stream():
+ ray.init(load_code_from_local=True, include_java=True)
+ ctx = StreamingContext.Builder().build()
+ stream = ctx.from_values(1, 2, 3)
+ java_stream = stream.as_java_stream()
+ python_stream = java_stream.as_python_stream()
+ assert stream.get_id() == java_stream.get_id()
+ assert stream.get_id() == python_stream.get_id()
+ python_stream.set_parallelism(10)
+ assert stream.get_parallelism() == java_stream.get_parallelism()
+ assert stream.get_parallelism() == python_stream.get_parallelism()
+ ray.shutdown()
+
+
+def test_key_data_stream():
+ ray.init(load_code_from_local=True, include_java=True)
+ ctx = StreamingContext.Builder().build()
+ key_stream = ctx.from_values(
+ "a", "b", "c").map(lambda x: (x, 1)).key_by(lambda x: x[0])
+ java_stream = key_stream.as_java_stream()
+ python_stream = java_stream.as_python_stream()
+ assert key_stream.get_id() == java_stream.get_id()
+ assert key_stream.get_id() == python_stream.get_id()
+ python_stream.set_parallelism(10)
+ assert key_stream.get_parallelism() == java_stream.get_parallelism()
+ assert key_stream.get_parallelism() == python_stream.get_parallelism()
+ ray.shutdown()
diff --git a/streaming/python/tests/test_word_count.py b/streaming/python/tests/test_word_count.py
index d86595cf4..03d9d7652 100644
--- a/streaming/python/tests/test_word_count.py
+++ b/streaming/python/tests/test_word_count.py
@@ -32,7 +32,9 @@ def test_simple_word_count():
def sink_func(x):
with open(sink_file, "a") as f:
- f.write("{}:{},".format(x[0], x[1]))
+ line = "{}:{},".format(x[0], x[1])
+ print("sink_func", line)
+ f.write(line)
ctx.from_values("a", "b", "c") \
.set_parallelism(1) \
diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.cc b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.cc
index e2cb2e861..98acc36a1 100644
--- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.cc
+++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.cc
@@ -26,6 +26,13 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(
return reinterpret_cast(reader_client);
}
+JNIEXPORT void JNICALL
+Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative(
+ JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
+ auto *writer_client = reinterpret_cast(ptr);
+ writer_client->OnWriterMessage(JByteArrayToBuffer(env, bytes));
+}
+
JNIEXPORT jbyteArray JNICALL
Java_io_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative(
JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {