Streaming rich function (#8602)

This commit is contained in:
chaokunyang
2020-05-27 18:36:07 +08:00
committed by GitHub
parent bd4fbcd7fc
commit bcdbe2d3d4
20 changed files with 264 additions and 71 deletions
@@ -20,14 +20,20 @@ public interface RuntimeContext {
int getParallelism();
/**
* @return config of current function
*/
Map<String, String> getConfig();
/**
* @return config of the job
*/
Map<String, String> getJobConfig();
Long getCheckpointId();
void setCheckpointId(long checkpointId);
Long getMaxBatch();
Map<String, String> getConfig();
void setCurrentKey(Object key);
KeyStateBackend getKeyStateBackend();
@@ -0,0 +1,24 @@
package io.ray.streaming.api.function;
import io.ray.streaming.api.context.RuntimeContext;
/**
* An interface for all user-defined functions to define the life cycle methods of the
* functions, and access the task context where the functions get executed.
*/
public interface RichFunction extends Function {
/**
* Initialization method for user function which called before the first call to the user
* function.
* @param runtimeContext runtime context
*/
void open(RuntimeContext runtimeContext);
/**
* Tear-down method for the user function which called after the last call to
* the user function.
*/
void close();
}
@@ -0,0 +1,40 @@
package io.ray.streaming.api.function.internal;
import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.api.function.Function;
import io.ray.streaming.api.function.RichFunction;
/**
* A util class for {@link Function}
*/
public class Functions {
private static class DefaultRichFunction implements RichFunction {
private final Function function;
private DefaultRichFunction(Function function) {
this.function = function;
}
@Override
public void open(RuntimeContext runtimeContext) {
}
@Override
public void close() {
}
public Function getFunction() {
return function;
}
}
public static RichFunction wrap(Function function) {
if (function instanceof RichFunction) {
return (RichFunction) function;
} else {
return new DefaultRichFunction(function);
}
}
}
@@ -9,6 +9,8 @@ import io.ray.streaming.operator.Operator;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.python.PythonPartition;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
/**
* Abstract base class of all stream types.
@@ -23,6 +25,7 @@ public abstract class Stream<S extends Stream<S, T>, T>
private final Stream inputStream;
private final StreamOperator operator;
private int parallelism = 1;
private Map<String, String> config = new HashMap<>();
private Partition<T> partition;
private Stream originalStream;
@@ -134,6 +137,25 @@ public abstract class Stream<S extends Stream<S, T>, T>
return self();
}
public S withConfig(Map<String, String> config) {
config.forEach(this::withConfig);
return self();
}
public S withConfig(String key, String value) {
if (isProxyStream()) {
originalStream.withConfig(key, value);
} else {
this.config.put(key, value);
}
return self();
}
@SuppressWarnings("unchecked")
public Map<String, String> getConfig() {
return isProxyStream() ? originalStream.getConfig() : config;
}
public boolean isProxyStream() {
return originalStream != null;
}
@@ -77,6 +77,7 @@ public class JobGraphBuilder {
} else {
throw new UnsupportedOperationException("Unsupported stream: " + stream);
}
jobVertex.setConfig(stream.getConfig());
this.jobGraph.addVertex(jobVertex);
}
@@ -4,20 +4,23 @@ import com.google.common.base.MoreObjects;
import io.ray.streaming.api.Language;
import io.ray.streaming.operator.StreamOperator;
import java.io.Serializable;
import java.util.Map;
/**
* Job vertex is a cell node where logic is executed.
*/
public class JobVertex implements Serializable {
private int vertexId;
private int parallelism;
private VertexType vertexType;
private Language language;
private StreamOperator streamOperator;
private Map<String, String> config;
public JobVertex(int vertexId, int parallelism, VertexType vertexType,
StreamOperator streamOperator) {
public JobVertex(int vertexId,
int parallelism,
VertexType vertexType,
StreamOperator streamOperator) {
this.vertexId = vertexId;
this.parallelism = parallelism;
this.vertexType = vertexType;
@@ -45,6 +48,14 @@ public class JobVertex implements Serializable {
return language;
}
public Map<String, String> getConfig() {
return config;
}
public void setConfig(Map<String, String> config) {
this.config = config;
}
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
@@ -53,7 +64,7 @@ public class JobVertex implements Serializable {
.add("vertexType", vertexType)
.add("language", language)
.add("streamOperator", streamOperator)
.add("config", config)
.toString();
}
}
@@ -4,25 +4,30 @@ import io.ray.streaming.api.Language;
import io.ray.streaming.api.collector.Collector;
import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.api.function.Function;
import io.ray.streaming.api.function.RichFunction;
import io.ray.streaming.api.function.internal.Functions;
import io.ray.streaming.message.KeyRecord;
import io.ray.streaming.message.Record;
import java.util.List;
public abstract class StreamOperator<F extends Function> implements Operator {
protected String name;
protected F function;
protected final String name;
protected final F function;
protected final RichFunction richFunction;
protected List<Collector> collectorList;
protected RuntimeContext runtimeContext;
public StreamOperator(F function) {
this.name = getClass().getSimpleName();
this.function = function;
this.richFunction = Functions.wrap(function);
}
@Override
public void open(List<Collector> collectorList, RuntimeContext runtimeContext) {
this.collectorList = collectorList;
this.runtimeContext = runtimeContext;
richFunction.open(runtimeContext);
}
@Override
@@ -32,7 +37,7 @@ public abstract class StreamOperator<F extends Function> implements Operator {
@Override
public void close() {
richFunction.close();
}
@Override
@@ -2,17 +2,6 @@ package io.ray.streaming.util;
public class Config {
/**
* Maximum number of batches to run in a streaming job.
*/
public static final String STREAMING_BATCH_MAX_COUNT = "streaming.batch.max.count";
/**
* batch frequency in milliseconds
*/
public static final String STREAMING_BATCH_FREQUENCY = "streaming.batch.frequency";
public static final long STREAMING_BATCH_FREQUENCY_DEFAULT = 1000;
public static final String STREAMING_JOB_NAME = "streaming.job.name";
public static final String STREAMING_OP_NAME = "streaming.op_name";
public static final String STREAMING_WORKER_NAME = "streaming.worker_name";
@@ -6,6 +6,7 @@ import io.ray.streaming.operator.StreamOperator;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* A node in the physical execution graph.
@@ -13,15 +14,17 @@ import java.util.List;
public class ExecutionNode implements Serializable {
private int nodeId;
private int parallelism;
private Map<String, String> config;
private NodeType nodeType;
private StreamOperator streamOperator;
private List<ExecutionTask> executionTasks;
private List<ExecutionEdge> inputsEdges;
private List<ExecutionEdge> outputEdges;
public ExecutionNode(int nodeId, int parallelism) {
public ExecutionNode(int nodeId, int parallelism, Map<String, String> config) {
this.nodeId = nodeId;
this.parallelism = parallelism;
this.config = config;
this.executionTasks = new ArrayList<>();
this.inputsEdges = new ArrayList<>();
this.outputEdges = new ArrayList<>();
@@ -43,6 +46,10 @@ public class ExecutionNode implements Serializable {
this.parallelism = parallelism;
}
public Map<String, String> getConfig() {
return config;
}
public List<ExecutionTask> getExecutionTasks() {
return executionTasks;
}
@@ -31,13 +31,12 @@ import org.slf4j.LoggerFactory;
public class PythonGateway {
private static final Logger LOG = LoggerFactory.getLogger(PythonGateway.class);
private static final String REFERENCE_ID_PREFIX = "__gateway_reference_id__";
private static MsgPackSerializer serializer = new MsgPackSerializer();
private MsgPackSerializer serializer;
private Map<String, Object> referenceMap;
private StreamingContext streamingContext;
public PythonGateway() {
serializer = new MsgPackSerializer();
referenceMap = new HashMap<>();
LOG.info("PythonGateway created");
}
@@ -156,8 +155,23 @@ public class PythonGateway {
.map((Function<Class, Class>) Primitives::unwrap)
.toArray(Class[]::new);
Optional<Method> any = methods.stream()
.filter(m -> Arrays.equals(m.getParameterTypes(), paramsTypes) ||
Arrays.equals(m.getParameterTypes(), unwrappedTypes))
.filter(m -> {
boolean exactMatch = Arrays.equals(m.getParameterTypes(), paramsTypes) ||
Arrays.equals(m.getParameterTypes(), unwrappedTypes);
if (exactMatch) {
return true;
} else if (paramsTypes.length == m.getParameterTypes().length) {
for (int i = 0; i < m.getParameterTypes().length; i++) {
Class<?> parameterType = m.getParameterTypes()[i];
if (!parameterType.isAssignableFrom(paramsTypes[i])) {
return false;
}
}
return true;
} else {
return false;
}
})
.findAny();
Preconditions.checkArgument(any.isPresent(),
String.format("Method %s with type %s doesn't exist on class %s",
@@ -166,7 +180,21 @@ public class PythonGateway {
}
private static boolean returnReference(Object value) {
return !(value instanceof Number) && !(value instanceof String) && !(value instanceof byte[]);
if (isBasic(value)) {
return false;
} else {
try {
serializer.serialize(value);
return false;
} catch (Exception e) {
return true;
}
}
}
private static boolean isBasic(Object value) {
return value == null || (value instanceof Boolean) || (value instanceof Number) ||
(value instanceof String) || (value instanceof byte[]);
}
public byte[] newInstance(byte[] classNameBytes) {
@@ -38,7 +38,7 @@ public class TaskAssignerImpl implements TaskAssigner {
Map<Integer, ExecutionNode> idToExecutionNode = new HashMap<>();
for (JobVertex jobVertex : jobVertices) {
ExecutionNode executionNode = new ExecutionNode(jobVertex.getVertexId(),
jobVertex.getParallelism());
jobVertex.getParallelism(), jobVertex.getConfig());
executionNode.setNodeType(jobVertex.getVertexType());
List<ExecutionTask> vertexTasks = new ArrayList<>();
for (int taskIndex = 0; taskIndex < jobVertex.getParallelism(); taskIndex++) {
@@ -1,7 +1,5 @@
package io.ray.streaming.runtime.worker.context;
import static io.ray.streaming.util.Config.STREAMING_BATCH_MAX_COUNT;
import com.google.common.base.Preconditions;
import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.runtime.core.graph.ExecutionTask;
@@ -21,8 +19,6 @@ import java.util.Map;
* Use Ray to implement RuntimeContext.
*/
public class RayRuntimeContext implements RuntimeContext {
private final Long maxBatch;
/**
* Backend for keyed state. This might be empty if we're not on a keyed stream.
*/
@@ -43,11 +39,6 @@ public class RayRuntimeContext implements RuntimeContext {
this.config = config;
this.taskIndex = executionTask.getTaskIndex();
this.parallelism = parallelism;
if (config.containsKey(STREAMING_BATCH_MAX_COUNT)) {
this.maxBatch = Long.valueOf(config.get(STREAMING_BATCH_MAX_COUNT));
} else {
this.maxBatch = Long.MAX_VALUE;
}
}
@Override
@@ -65,6 +56,16 @@ public class RayRuntimeContext implements RuntimeContext {
return parallelism;
}
@Override
public Map<String, String> getConfig() {
return config;
}
@Override
public Map<String, String> getJobConfig() {
return config;
}
@Override
public Long getCheckpointId() {
return checkpointId;
@@ -81,17 +82,6 @@ public class RayRuntimeContext implements RuntimeContext {
this.checkpointId = checkpointId;
}
@Override
public Long getMaxBatch() {
return maxBatch;
}
@Override
public Map<String, String> getConfig() {
return config;
}
@Override
public void setCurrentKey(Object key) {
this.keyStateBackend.setCurrentKey(key);
@@ -33,7 +33,6 @@ public class WordCountTest extends BaseUnitTest implements Serializable {
Ray.shutdown();
StreamingContext streamingContext = StreamingContext.buildContext();
Map<String, String> config = new HashMap<>();
config.put(Config.STREAMING_BATCH_MAX_COUNT, "1");
config.put(Config.CHANNEL_TYPE, Config.MEMORY_CHANNEL);
streamingContext.withConfig(config);
List<String> text = new ArrayList<>();
@@ -160,7 +160,6 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
Map<String, Integer> wordCount = new ConcurrentHashMap<>();
StreamingContext streamingContext = StreamingContext.buildContext();
Map<String, String> config = new HashMap<>();
config.put(Config.STREAMING_BATCH_MAX_COUNT, "1");
config.put(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL);
config.put(Config.CHANNEL_SIZE, "100000");
streamingContext.withConfig(config);