mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 16:31:16 +08:00
Streaming rich function (#8602)
This commit is contained in:
+10
-4
@@ -20,14 +20,20 @@ public interface RuntimeContext {
|
||||
|
||||
int getParallelism();
|
||||
|
||||
/**
|
||||
* @return config of current function
|
||||
*/
|
||||
Map<String, String> getConfig();
|
||||
|
||||
/**
|
||||
* @return config of the job
|
||||
*/
|
||||
Map<String, String> getJobConfig();
|
||||
|
||||
Long getCheckpointId();
|
||||
|
||||
void setCheckpointId(long checkpointId);
|
||||
|
||||
Long getMaxBatch();
|
||||
|
||||
Map<String, String> getConfig();
|
||||
|
||||
void setCurrentKey(Object key);
|
||||
|
||||
KeyStateBackend getKeyStateBackend();
|
||||
|
||||
+24
@@ -0,0 +1,24 @@
|
||||
package io.ray.streaming.api.function;
|
||||
|
||||
import io.ray.streaming.api.context.RuntimeContext;
|
||||
|
||||
/**
|
||||
* An interface for all user-defined functions to define the life cycle methods of the
|
||||
* functions, and access the task context where the functions get executed.
|
||||
*/
|
||||
public interface RichFunction extends Function {
|
||||
|
||||
/**
|
||||
* Initialization method for user function which called before the first call to the user
|
||||
* function.
|
||||
* @param runtimeContext runtime context
|
||||
*/
|
||||
void open(RuntimeContext runtimeContext);
|
||||
|
||||
/**
|
||||
* Tear-down method for the user function which called after the last call to
|
||||
* the user function.
|
||||
*/
|
||||
void close();
|
||||
|
||||
}
|
||||
+40
@@ -0,0 +1,40 @@
|
||||
package io.ray.streaming.api.function.internal;
|
||||
|
||||
import io.ray.streaming.api.context.RuntimeContext;
|
||||
import io.ray.streaming.api.function.Function;
|
||||
import io.ray.streaming.api.function.RichFunction;
|
||||
|
||||
/**
|
||||
* A util class for {@link Function}
|
||||
*/
|
||||
public class Functions {
|
||||
|
||||
private static class DefaultRichFunction implements RichFunction {
|
||||
private final Function function;
|
||||
|
||||
private DefaultRichFunction(Function function) {
|
||||
this.function = function;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void open(RuntimeContext runtimeContext) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
}
|
||||
|
||||
public Function getFunction() {
|
||||
return function;
|
||||
}
|
||||
}
|
||||
|
||||
public static RichFunction wrap(Function function) {
|
||||
if (function instanceof RichFunction) {
|
||||
return (RichFunction) function;
|
||||
} else {
|
||||
return new DefaultRichFunction(function);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -9,6 +9,8 @@ import io.ray.streaming.operator.Operator;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import io.ray.streaming.python.PythonPartition;
|
||||
import java.io.Serializable;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Abstract base class of all stream types.
|
||||
@@ -23,6 +25,7 @@ public abstract class Stream<S extends Stream<S, T>, T>
|
||||
private final Stream inputStream;
|
||||
private final StreamOperator operator;
|
||||
private int parallelism = 1;
|
||||
private Map<String, String> config = new HashMap<>();
|
||||
private Partition<T> partition;
|
||||
private Stream originalStream;
|
||||
|
||||
@@ -134,6 +137,25 @@ public abstract class Stream<S extends Stream<S, T>, T>
|
||||
return self();
|
||||
}
|
||||
|
||||
public S withConfig(Map<String, String> config) {
|
||||
config.forEach(this::withConfig);
|
||||
return self();
|
||||
}
|
||||
|
||||
public S withConfig(String key, String value) {
|
||||
if (isProxyStream()) {
|
||||
originalStream.withConfig(key, value);
|
||||
} else {
|
||||
this.config.put(key, value);
|
||||
}
|
||||
return self();
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public Map<String, String> getConfig() {
|
||||
return isProxyStream() ? originalStream.getConfig() : config;
|
||||
}
|
||||
|
||||
public boolean isProxyStream() {
|
||||
return originalStream != null;
|
||||
}
|
||||
|
||||
+1
@@ -77,6 +77,7 @@ public class JobGraphBuilder {
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Unsupported stream: " + stream);
|
||||
}
|
||||
jobVertex.setConfig(stream.getConfig());
|
||||
this.jobGraph.addVertex(jobVertex);
|
||||
}
|
||||
|
||||
|
||||
+15
-4
@@ -4,20 +4,23 @@ import com.google.common.base.MoreObjects;
|
||||
import io.ray.streaming.api.Language;
|
||||
import io.ray.streaming.operator.StreamOperator;
|
||||
import java.io.Serializable;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Job vertex is a cell node where logic is executed.
|
||||
*/
|
||||
public class JobVertex implements Serializable {
|
||||
|
||||
private int vertexId;
|
||||
private int parallelism;
|
||||
private VertexType vertexType;
|
||||
private Language language;
|
||||
private StreamOperator streamOperator;
|
||||
private Map<String, String> config;
|
||||
|
||||
public JobVertex(int vertexId, int parallelism, VertexType vertexType,
|
||||
StreamOperator streamOperator) {
|
||||
public JobVertex(int vertexId,
|
||||
int parallelism,
|
||||
VertexType vertexType,
|
||||
StreamOperator streamOperator) {
|
||||
this.vertexId = vertexId;
|
||||
this.parallelism = parallelism;
|
||||
this.vertexType = vertexType;
|
||||
@@ -45,6 +48,14 @@ public class JobVertex implements Serializable {
|
||||
return language;
|
||||
}
|
||||
|
||||
public Map<String, String> getConfig() {
|
||||
return config;
|
||||
}
|
||||
|
||||
public void setConfig(Map<String, String> config) {
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return MoreObjects.toStringHelper(this)
|
||||
@@ -53,7 +64,7 @@ public class JobVertex implements Serializable {
|
||||
.add("vertexType", vertexType)
|
||||
.add("language", language)
|
||||
.add("streamOperator", streamOperator)
|
||||
.add("config", config)
|
||||
.toString();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
+8
-3
@@ -4,25 +4,30 @@ import io.ray.streaming.api.Language;
|
||||
import io.ray.streaming.api.collector.Collector;
|
||||
import io.ray.streaming.api.context.RuntimeContext;
|
||||
import io.ray.streaming.api.function.Function;
|
||||
import io.ray.streaming.api.function.RichFunction;
|
||||
import io.ray.streaming.api.function.internal.Functions;
|
||||
import io.ray.streaming.message.KeyRecord;
|
||||
import io.ray.streaming.message.Record;
|
||||
import java.util.List;
|
||||
|
||||
public abstract class StreamOperator<F extends Function> implements Operator {
|
||||
protected String name;
|
||||
protected F function;
|
||||
protected final String name;
|
||||
protected final F function;
|
||||
protected final RichFunction richFunction;
|
||||
protected List<Collector> collectorList;
|
||||
protected RuntimeContext runtimeContext;
|
||||
|
||||
public StreamOperator(F function) {
|
||||
this.name = getClass().getSimpleName();
|
||||
this.function = function;
|
||||
this.richFunction = Functions.wrap(function);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void open(List<Collector> collectorList, RuntimeContext runtimeContext) {
|
||||
this.collectorList = collectorList;
|
||||
this.runtimeContext = runtimeContext;
|
||||
richFunction.open(runtimeContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -32,7 +37,7 @@ public abstract class StreamOperator<F extends Function> implements Operator {
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
|
||||
richFunction.close();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -2,17 +2,6 @@ package io.ray.streaming.util;
|
||||
|
||||
public class Config {
|
||||
|
||||
/**
|
||||
* Maximum number of batches to run in a streaming job.
|
||||
*/
|
||||
public static final String STREAMING_BATCH_MAX_COUNT = "streaming.batch.max.count";
|
||||
|
||||
/**
|
||||
* batch frequency in milliseconds
|
||||
*/
|
||||
public static final String STREAMING_BATCH_FREQUENCY = "streaming.batch.frequency";
|
||||
public static final long STREAMING_BATCH_FREQUENCY_DEFAULT = 1000;
|
||||
|
||||
public static final String STREAMING_JOB_NAME = "streaming.job.name";
|
||||
public static final String STREAMING_OP_NAME = "streaming.op_name";
|
||||
public static final String STREAMING_WORKER_NAME = "streaming.worker_name";
|
||||
|
||||
+8
-1
@@ -6,6 +6,7 @@ import io.ray.streaming.operator.StreamOperator;
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* A node in the physical execution graph.
|
||||
@@ -13,15 +14,17 @@ import java.util.List;
|
||||
public class ExecutionNode implements Serializable {
|
||||
private int nodeId;
|
||||
private int parallelism;
|
||||
private Map<String, String> config;
|
||||
private NodeType nodeType;
|
||||
private StreamOperator streamOperator;
|
||||
private List<ExecutionTask> executionTasks;
|
||||
private List<ExecutionEdge> inputsEdges;
|
||||
private List<ExecutionEdge> outputEdges;
|
||||
|
||||
public ExecutionNode(int nodeId, int parallelism) {
|
||||
public ExecutionNode(int nodeId, int parallelism, Map<String, String> config) {
|
||||
this.nodeId = nodeId;
|
||||
this.parallelism = parallelism;
|
||||
this.config = config;
|
||||
this.executionTasks = new ArrayList<>();
|
||||
this.inputsEdges = new ArrayList<>();
|
||||
this.outputEdges = new ArrayList<>();
|
||||
@@ -43,6 +46,10 @@ public class ExecutionNode implements Serializable {
|
||||
this.parallelism = parallelism;
|
||||
}
|
||||
|
||||
public Map<String, String> getConfig() {
|
||||
return config;
|
||||
}
|
||||
|
||||
public List<ExecutionTask> getExecutionTasks() {
|
||||
return executionTasks;
|
||||
}
|
||||
|
||||
+33
-5
@@ -31,13 +31,12 @@ import org.slf4j.LoggerFactory;
|
||||
public class PythonGateway {
|
||||
private static final Logger LOG = LoggerFactory.getLogger(PythonGateway.class);
|
||||
private static final String REFERENCE_ID_PREFIX = "__gateway_reference_id__";
|
||||
private static MsgPackSerializer serializer = new MsgPackSerializer();
|
||||
|
||||
private MsgPackSerializer serializer;
|
||||
private Map<String, Object> referenceMap;
|
||||
private StreamingContext streamingContext;
|
||||
|
||||
public PythonGateway() {
|
||||
serializer = new MsgPackSerializer();
|
||||
referenceMap = new HashMap<>();
|
||||
LOG.info("PythonGateway created");
|
||||
}
|
||||
@@ -156,8 +155,23 @@ public class PythonGateway {
|
||||
.map((Function<Class, Class>) Primitives::unwrap)
|
||||
.toArray(Class[]::new);
|
||||
Optional<Method> any = methods.stream()
|
||||
.filter(m -> Arrays.equals(m.getParameterTypes(), paramsTypes) ||
|
||||
Arrays.equals(m.getParameterTypes(), unwrappedTypes))
|
||||
.filter(m -> {
|
||||
boolean exactMatch = Arrays.equals(m.getParameterTypes(), paramsTypes) ||
|
||||
Arrays.equals(m.getParameterTypes(), unwrappedTypes);
|
||||
if (exactMatch) {
|
||||
return true;
|
||||
} else if (paramsTypes.length == m.getParameterTypes().length) {
|
||||
for (int i = 0; i < m.getParameterTypes().length; i++) {
|
||||
Class<?> parameterType = m.getParameterTypes()[i];
|
||||
if (!parameterType.isAssignableFrom(paramsTypes[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
})
|
||||
.findAny();
|
||||
Preconditions.checkArgument(any.isPresent(),
|
||||
String.format("Method %s with type %s doesn't exist on class %s",
|
||||
@@ -166,7 +180,21 @@ public class PythonGateway {
|
||||
}
|
||||
|
||||
private static boolean returnReference(Object value) {
|
||||
return !(value instanceof Number) && !(value instanceof String) && !(value instanceof byte[]);
|
||||
if (isBasic(value)) {
|
||||
return false;
|
||||
} else {
|
||||
try {
|
||||
serializer.serialize(value);
|
||||
return false;
|
||||
} catch (Exception e) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static boolean isBasic(Object value) {
|
||||
return value == null || (value instanceof Boolean) || (value instanceof Number) ||
|
||||
(value instanceof String) || (value instanceof byte[]);
|
||||
}
|
||||
|
||||
public byte[] newInstance(byte[] classNameBytes) {
|
||||
|
||||
+1
-1
@@ -38,7 +38,7 @@ public class TaskAssignerImpl implements TaskAssigner {
|
||||
Map<Integer, ExecutionNode> idToExecutionNode = new HashMap<>();
|
||||
for (JobVertex jobVertex : jobVertices) {
|
||||
ExecutionNode executionNode = new ExecutionNode(jobVertex.getVertexId(),
|
||||
jobVertex.getParallelism());
|
||||
jobVertex.getParallelism(), jobVertex.getConfig());
|
||||
executionNode.setNodeType(jobVertex.getVertexType());
|
||||
List<ExecutionTask> vertexTasks = new ArrayList<>();
|
||||
for (int taskIndex = 0; taskIndex < jobVertex.getParallelism(); taskIndex++) {
|
||||
|
||||
+10
-20
@@ -1,7 +1,5 @@
|
||||
package io.ray.streaming.runtime.worker.context;
|
||||
|
||||
import static io.ray.streaming.util.Config.STREAMING_BATCH_MAX_COUNT;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.streaming.api.context.RuntimeContext;
|
||||
import io.ray.streaming.runtime.core.graph.ExecutionTask;
|
||||
@@ -21,8 +19,6 @@ import java.util.Map;
|
||||
* Use Ray to implement RuntimeContext.
|
||||
*/
|
||||
public class RayRuntimeContext implements RuntimeContext {
|
||||
|
||||
private final Long maxBatch;
|
||||
/**
|
||||
* Backend for keyed state. This might be empty if we're not on a keyed stream.
|
||||
*/
|
||||
@@ -43,11 +39,6 @@ public class RayRuntimeContext implements RuntimeContext {
|
||||
this.config = config;
|
||||
this.taskIndex = executionTask.getTaskIndex();
|
||||
this.parallelism = parallelism;
|
||||
if (config.containsKey(STREAMING_BATCH_MAX_COUNT)) {
|
||||
this.maxBatch = Long.valueOf(config.get(STREAMING_BATCH_MAX_COUNT));
|
||||
} else {
|
||||
this.maxBatch = Long.MAX_VALUE;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -65,6 +56,16 @@ public class RayRuntimeContext implements RuntimeContext {
|
||||
return parallelism;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, String> getConfig() {
|
||||
return config;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, String> getJobConfig() {
|
||||
return config;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getCheckpointId() {
|
||||
return checkpointId;
|
||||
@@ -81,17 +82,6 @@ public class RayRuntimeContext implements RuntimeContext {
|
||||
this.checkpointId = checkpointId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getMaxBatch() {
|
||||
return maxBatch;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, String> getConfig() {
|
||||
return config;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void setCurrentKey(Object key) {
|
||||
this.keyStateBackend.setCurrentKey(key);
|
||||
|
||||
-1
@@ -33,7 +33,6 @@ public class WordCountTest extends BaseUnitTest implements Serializable {
|
||||
Ray.shutdown();
|
||||
StreamingContext streamingContext = StreamingContext.buildContext();
|
||||
Map<String, String> config = new HashMap<>();
|
||||
config.put(Config.STREAMING_BATCH_MAX_COUNT, "1");
|
||||
config.put(Config.CHANNEL_TYPE, Config.MEMORY_CHANNEL);
|
||||
streamingContext.withConfig(config);
|
||||
List<String> text = new ArrayList<>();
|
||||
|
||||
-1
@@ -160,7 +160,6 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
|
||||
Map<String, Integer> wordCount = new ConcurrentHashMap<>();
|
||||
StreamingContext streamingContext = StreamingContext.buildContext();
|
||||
Map<String, String> config = new HashMap<>();
|
||||
config.put(Config.STREAMING_BATCH_MAX_COUNT, "1");
|
||||
config.put(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL);
|
||||
config.put(Config.CHANNEL_SIZE, "100000");
|
||||
streamingContext.withConfig(config);
|
||||
|
||||
@@ -151,12 +151,30 @@ class RuntimeContext(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_config(self):
|
||||
"""
|
||||
Returns:
|
||||
The config with which the parallel task runs.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_job_config(self):
|
||||
"""
|
||||
Returns:
|
||||
The job config.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class RuntimeContextImpl(RuntimeContext):
|
||||
def __init__(self, task_id, task_index, parallelism):
|
||||
def __init__(self, task_id, task_index, parallelism, **kargs):
|
||||
self.task_id = task_id
|
||||
self.task_index = task_index
|
||||
self.parallelism = parallelism
|
||||
self.config = kargs.get("config", {})
|
||||
self.job_config = kargs.get("job_config", {})
|
||||
|
||||
def get_task_id(self):
|
||||
return self.task_id
|
||||
@@ -166,3 +184,9 @@ class RuntimeContextImpl(RuntimeContext):
|
||||
|
||||
def get_parallelism(self):
|
||||
return self.parallelism
|
||||
|
||||
def get_config(self):
|
||||
return self.config
|
||||
|
||||
def get_job_config(self):
|
||||
return self.job_config
|
||||
|
||||
@@ -59,6 +59,38 @@ class Stream(ABC):
|
||||
return self._gateway_client(). \
|
||||
call_method(self._j_stream, "getId")
|
||||
|
||||
def with_config(self, key=None, value=None, conf=None):
|
||||
"""Set stream config.
|
||||
|
||||
Args:
|
||||
key: a key name string for configuration property
|
||||
value: a value string for configuration property
|
||||
conf: multi key-value pairs as a dict
|
||||
|
||||
Returns:
|
||||
self
|
||||
"""
|
||||
if key is not None:
|
||||
assert value
|
||||
assert type(key) is str
|
||||
assert type(value) is str
|
||||
self._gateway_client(). \
|
||||
call_method(self._j_stream, "withConfig", key, value)
|
||||
if conf is not None:
|
||||
for k, v in conf.items():
|
||||
assert type(k) is str
|
||||
assert type(v) is str
|
||||
self._gateway_client(). \
|
||||
call_method(self._j_stream, "withConfig", conf)
|
||||
return self
|
||||
|
||||
def get_config(self):
|
||||
"""
|
||||
Returns:
|
||||
A dict config for this stream
|
||||
"""
|
||||
return self._gateway_client().call_method(self._j_stream, "getConfig")
|
||||
|
||||
@abstractmethod
|
||||
def get_language(self):
|
||||
pass
|
||||
@@ -252,7 +284,7 @@ class JavaDataStream(Stream):
|
||||
"""
|
||||
Represents a stream of data which applies a transformation executed by
|
||||
java. It's also a wrapper of java
|
||||
`org.ray.streaming.api.stream.DataStream`
|
||||
`io.ray.streaming.api.stream.DataStream`
|
||||
"""
|
||||
|
||||
def __init__(self, input_stream, j_stream, streaming_context=None):
|
||||
@@ -263,39 +295,39 @@ class JavaDataStream(Stream):
|
||||
return function.Language.JAVA
|
||||
|
||||
def map(self, java_func_class):
|
||||
"""See org.ray.streaming.api.stream.DataStream.map"""
|
||||
"""See io.ray.streaming.api.stream.DataStream.map"""
|
||||
return JavaDataStream(self, self._unary_call("map", java_func_class))
|
||||
|
||||
def flat_map(self, java_func_class):
|
||||
"""See org.ray.streaming.api.stream.DataStream.flatMap"""
|
||||
"""See io.ray.streaming.api.stream.DataStream.flatMap"""
|
||||
return JavaDataStream(self, self._unary_call("flatMap",
|
||||
java_func_class))
|
||||
|
||||
def filter(self, java_func_class):
|
||||
"""See org.ray.streaming.api.stream.DataStream.filter"""
|
||||
"""See io.ray.streaming.api.stream.DataStream.filter"""
|
||||
return JavaDataStream(self, self._unary_call("filter",
|
||||
java_func_class))
|
||||
|
||||
def key_by(self, java_func_class):
|
||||
"""See org.ray.streaming.api.stream.DataStream.keyBy"""
|
||||
"""See io.ray.streaming.api.stream.DataStream.keyBy"""
|
||||
self._check_partition_call()
|
||||
return JavaKeyDataStream(self,
|
||||
self._unary_call("keyBy", java_func_class))
|
||||
|
||||
def broadcast(self, java_func_class):
|
||||
"""See org.ray.streaming.api.stream.DataStream.broadcast"""
|
||||
"""See io.ray.streaming.api.stream.DataStream.broadcast"""
|
||||
self._check_partition_call()
|
||||
return JavaDataStream(self,
|
||||
self._unary_call("broadcast", java_func_class))
|
||||
|
||||
def partition_by(self, java_func_class):
|
||||
"""See org.ray.streaming.api.stream.DataStream.partitionBy"""
|
||||
"""See io.ray.streaming.api.stream.DataStream.partitionBy"""
|
||||
self._check_partition_call()
|
||||
return JavaDataStream(self,
|
||||
self._unary_call("partitionBy", java_func_class))
|
||||
|
||||
def sink(self, java_func_class):
|
||||
"""See org.ray.streaming.api.stream.DataStream.sink"""
|
||||
"""See io.ray.streaming.api.stream.DataStream.sink"""
|
||||
return JavaStreamSink(self, self._unary_call("sink", java_func_class))
|
||||
|
||||
def as_python_stream(self):
|
||||
@@ -374,14 +406,14 @@ class KeyDataStream(DataStream):
|
||||
class JavaKeyDataStream(JavaDataStream):
|
||||
"""
|
||||
Represents a DataStream returned by a key-by operation in java.
|
||||
Wrapper of org.ray.streaming.api.stream.KeyDataStream
|
||||
Wrapper of io.ray.streaming.api.stream.KeyDataStream
|
||||
"""
|
||||
|
||||
def __init__(self, input_stream, j_stream):
|
||||
super().__init__(input_stream, j_stream)
|
||||
|
||||
def reduce(self, java_func_class):
|
||||
"""See org.ray.streaming.api.stream.KeyDataStream.reduce"""
|
||||
"""See io.ray.streaming.api.stream.KeyDataStream.reduce"""
|
||||
return JavaDataStream(self,
|
||||
super()._unary_call("reduce", java_func_class))
|
||||
|
||||
@@ -425,7 +457,7 @@ class StreamSource(DataStream):
|
||||
|
||||
class JavaStreamSource(JavaDataStream):
|
||||
"""Represents a source of the java DataStream.
|
||||
Wrapper of java org.ray.streaming.api.stream.DataStreamSource
|
||||
Wrapper of java io.ray.streaming.api.stream.DataStreamSource
|
||||
"""
|
||||
|
||||
def __init__(self, j_stream, streaming_context):
|
||||
@@ -446,7 +478,7 @@ class JavaStreamSource(JavaDataStream):
|
||||
j_func = streaming_context._gateway_client() \
|
||||
.new_instance(java_source_func_class)
|
||||
j_stream = streaming_context._gateway_client() \
|
||||
.call_function("org.ray.streaming.api.stream.DataStreamSource"
|
||||
.call_function("io.ray.streaming.api.stream.DataStreamSource"
|
||||
"fromSource", streaming_context._j_ctx, j_func)
|
||||
return JavaStreamSource(j_stream, streaming_context)
|
||||
|
||||
@@ -465,7 +497,7 @@ class StreamSink(Stream):
|
||||
|
||||
class JavaStreamSink(Stream):
|
||||
"""Represents a sink of the java DataStream.
|
||||
Wrapper of java org.ray.streaming.api.stream.StreamSink
|
||||
Wrapper of java io.ray.streaming.api.stream.StreamSink
|
||||
"""
|
||||
|
||||
def __init__(self, input_stream, j_stream):
|
||||
|
||||
@@ -2,7 +2,6 @@ import enum
|
||||
import importlib
|
||||
import inspect
|
||||
import sys
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ray import cloudpickle
|
||||
@@ -17,7 +16,7 @@ class Language(enum.Enum):
|
||||
class Function(ABC):
|
||||
"""The base interface for all user-defined functions."""
|
||||
|
||||
def open(self, conf: typing.Dict[str, str]):
|
||||
def open(self, runtime_context):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
@@ -55,9 +54,6 @@ class SourceFunction(Function):
|
||||
"""
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
class MapFunction(Function):
|
||||
"""
|
||||
|
||||
@@ -71,12 +71,13 @@ class StreamOperator(Operator, ABC):
|
||||
def open(self, collectors, runtime_context):
|
||||
self.collectors = collectors
|
||||
self.runtime_context = runtime_context
|
||||
self.func.open(runtime_context)
|
||||
|
||||
def finish(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
self.func.close()
|
||||
|
||||
def collect(self, record):
|
||||
for collector in self.collectors:
|
||||
|
||||
@@ -83,6 +83,7 @@ class StreamTask(ABC):
|
||||
import atexit
|
||||
atexit.register(exit_handler)
|
||||
|
||||
# TODO(chaokunyang) add task/job config
|
||||
runtime_context = RuntimeContextImpl(
|
||||
self.worker.execution_task.task_id,
|
||||
self.worker.execution_task.task_index, execution_node.parallelism)
|
||||
|
||||
@@ -29,3 +29,21 @@ def test_key_data_stream():
|
||||
assert key_stream.get_parallelism() == java_stream.get_parallelism()
|
||||
assert key_stream.get_parallelism() == python_stream.get_parallelism()
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
def test_stream_config():
|
||||
ray.init(load_code_from_local=True, include_java=True)
|
||||
ctx = StreamingContext.Builder().build()
|
||||
stream = ctx.from_values(1, 2, 3)
|
||||
stream.with_config("k1", "v1")
|
||||
print("config", stream.get_config())
|
||||
assert stream.get_config() == {"k1": "v1"}
|
||||
stream.with_config(conf={"k2": "v2", "k3": "v3"})
|
||||
print("config", stream.get_config())
|
||||
assert stream.get_config() == {"k1": "v1", "k2": "v2", "k3": "v3"}
|
||||
java_stream = stream.as_java_stream()
|
||||
java_stream.with_config(conf={"k4": "v4"})
|
||||
config = java_stream.get_config()
|
||||
print("config", config)
|
||||
assert config == {"k1": "v1", "k2": "v2", "k3": "v3", "k4": "v4"}
|
||||
ray.shutdown()
|
||||
|
||||
Reference in New Issue
Block a user