[Streaming] Streaming Python API (#6755)

This commit is contained in:
chaokunyang
2020-02-25 10:33:33 +08:00
committed by GitHub
parent 2c1f4fd82c
commit 8b6784de06
71 changed files with 2701 additions and 1928 deletions
@@ -1,23 +0,0 @@
package org.ray.streaming.runtime.cluster;
import java.util.ArrayList;
import java.util.List;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.streaming.runtime.worker.JobWorker;
/**
 * Resource-Manager takes care of resource management; here that means creating
 * the worker actors.
 */
public class ResourceManager {

  /**
   * Creates {@code workerNum} {@link JobWorker} actors.
   *
   * @param workerNum number of actors to create; a non-positive value yields an empty list
   * @return handles of the created actors, in creation order
   */
  public List<RayActor<JobWorker>> createWorkers(int workerNum) {
    List<RayActor<JobWorker>> actorHandles = new ArrayList<>();
    int remaining = workerNum;
    while (remaining > 0) {
      RayActor<JobWorker> handle = Ray.createActor(JobWorker::new);
      actorHandles.add(handle);
      remaining--;
    }
    return actorHandles;
  }
}
@@ -7,7 +7,6 @@ import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.ray.api.RayActor;
import org.ray.streaming.runtime.worker.JobWorker;
/**
* Physical execution graph.
@@ -19,19 +18,19 @@ import org.ray.streaming.runtime.worker.JobWorker;
public class ExecutionGraph implements Serializable {
private long buildTime;
private List<ExecutionNode> executionNodeList;
private List<RayActor<JobWorker>> sourceWorkers = new ArrayList<>();
private List<RayActor<JobWorker>> sinkWorkers = new ArrayList<>();
private List<RayActor> sourceWorkers = new ArrayList<>();
private List<RayActor> sinkWorkers = new ArrayList<>();
public ExecutionGraph(List<ExecutionNode> executionNodes) {
this.executionNodeList = executionNodes;
for (ExecutionNode executionNode : executionNodeList) {
if (executionNode.getNodeType() == ExecutionNode.NodeType.SOURCE) {
List<RayActor<JobWorker>> actors = executionNode.getExecutionTasks().stream()
List<RayActor> actors = executionNode.getExecutionTasks().stream()
.map(ExecutionTask::getWorker).collect(Collectors.toList());
sourceWorkers.addAll(actors);
}
if (executionNode.getNodeType() == ExecutionNode.NodeType.SINK) {
List<RayActor<JobWorker>> actors = executionNode.getExecutionTasks().stream()
List<RayActor> actors = executionNode.getExecutionTasks().stream()
.map(ExecutionTask::getWorker).collect(Collectors.toList());
sinkWorkers.addAll(actors);
}
@@ -39,11 +38,11 @@ public class ExecutionGraph implements Serializable {
buildTime = System.currentTimeMillis();
}
public List<RayActor<JobWorker>> getSourceWorkers() {
public List<RayActor> getSourceWorkers() {
return sourceWorkers;
}
public List<RayActor<JobWorker>> getSinkWorkers() {
public List<RayActor> getSinkWorkers() {
return sinkWorkers;
}
@@ -82,10 +81,10 @@ public class ExecutionGraph implements Serializable {
throw new RuntimeException("Task " + taskId + " does not exist!");
}
public Map<Integer, RayActor<JobWorker>> getTaskId2WorkerByNodeId(int nodeId) {
public Map<Integer, RayActor> getTaskId2WorkerByNodeId(int nodeId) {
for (ExecutionNode executionNode : executionNodeList) {
if (executionNode.getNodeId() == nodeId) {
Map<Integer, RayActor<JobWorker>> taskId2Worker = new HashMap<>();
Map<Integer, RayActor> taskId2Worker = new HashMap<>();
for (ExecutionTask executionTask : executionNode.getExecutionTasks()) {
taskId2Worker.put(executionTask.getTaskId(), executionTask.getWorker());
}
@@ -3,6 +3,7 @@ package org.ray.streaming.runtime.core.graph;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.ray.streaming.api.Language;
import org.ray.streaming.jobgraph.VertexType;
import org.ray.streaming.operator.StreamOperator;
@@ -10,7 +11,6 @@ import org.ray.streaming.operator.StreamOperator;
* A node in the physical execution graph.
*/
public class ExecutionNode implements Serializable {
private int nodeId;
private int parallelism;
private NodeType nodeType;
@@ -59,7 +59,7 @@ public class ExecutionNode implements Serializable {
this.outputEdges = outputEdges;
}
public void addExecutionEdge(ExecutionEdge executionEdge) {
public void addOutputEdge(ExecutionEdge executionEdge) {
this.outputEdges.add(executionEdge);
}
@@ -79,6 +79,10 @@ public class ExecutionNode implements Serializable {
this.streamOperator = streamOperator;
}
/**
 * Returns the language of this node, as reported by its stream operator.
 *
 * @return the value of {@code streamOperator.getLanguage()}
 */
public Language getLanguage() {
return streamOperator.getLanguage();
}
public NodeType getNodeType() {
return nodeType;
}
@@ -92,7 +96,7 @@ public class ExecutionNode implements Serializable {
this.nodeType = NodeType.SINK;
break;
default:
this.nodeType = NodeType.PROCESS;
this.nodeType = NodeType.TRANSFORM;
}
}
@@ -109,7 +113,7 @@ public class ExecutionNode implements Serializable {
public enum NodeType {
SOURCE,
PROCESS,
TRANSFORM,
SINK,
}
}
@@ -2,7 +2,6 @@ package org.ray.streaming.runtime.core.graph;
import java.io.Serializable;
import org.ray.api.RayActor;
import org.ray.streaming.runtime.worker.JobWorker;
/**
* ExecutionTask is minimal execution unit.
@@ -12,9 +11,9 @@ import org.ray.streaming.runtime.worker.JobWorker;
public class ExecutionTask implements Serializable {
private int taskId;
private int taskIndex;
private RayActor<JobWorker> worker;
private RayActor worker;
public ExecutionTask(int taskId, int taskIndex, RayActor<JobWorker> worker) {
public ExecutionTask(int taskId, int taskIndex, RayActor worker) {
this.taskId = taskId;
this.taskIndex = taskIndex;
this.worker = worker;
@@ -36,11 +35,11 @@ public class ExecutionTask implements Serializable {
this.taskIndex = taskIndex;
}
public RayActor<JobWorker> getWorker() {
public RayActor getWorker() {
return worker;
}
public void setWorker(RayActor<JobWorker> worker) {
public void setWorker(RayActor worker) {
this.worker = worker;
}
}
@@ -0,0 +1,101 @@
package org.ray.streaming.runtime.python;
import com.google.protobuf.ByteString;
import java.util.Arrays;
import org.ray.runtime.actor.NativeRayActor;
import org.ray.streaming.api.function.Function;
import org.ray.streaming.api.partition.Partition;
import org.ray.streaming.python.PythonFunction;
import org.ray.streaming.python.PythonPartition;
import org.ray.streaming.runtime.core.graph.ExecutionEdge;
import org.ray.streaming.runtime.core.graph.ExecutionGraph;
import org.ray.streaming.runtime.core.graph.ExecutionNode;
import org.ray.streaming.runtime.core.graph.ExecutionTask;
import org.ray.streaming.runtime.generated.RemoteCall;
import org.ray.streaming.runtime.generated.Streaming;
/**
 * Builds the protobuf representation ({@link RemoteCall.ExecutionGraph}) of an
 * {@link ExecutionGraph}.
 */
public class GraphPbBuilder {
  private MsgPackSerializer serializer = new MsgPackSerializer();

  /**
   * For simple scenario, a single ExecutionNode is enough. But some cases may need
   * sub-graph information, so we serialize entire graph.
   *
   * @param graph the execution graph to convert
   * @return a protobuf message carrying the build time and every node of {@code graph}
   */
  public RemoteCall.ExecutionGraph buildExecutionGraphPb(ExecutionGraph graph) {
    RemoteCall.ExecutionGraph.Builder graphPb = RemoteCall.ExecutionGraph.newBuilder();
    graphPb.setBuildTime(graph.getBuildTime());
    for (ExecutionNode node : graph.getExecutionNodeList()) {
      graphPb.addExecutionNodes(buildNodePb(node));
    }
    return graphPb.build();
  }

  /**
   * Converts one node: its scalar attributes, serialized function, tasks, then input and
   * output edges.
   */
  private RemoteCall.ExecutionGraph.ExecutionNode buildNodePb(ExecutionNode node) {
    RemoteCall.ExecutionGraph.ExecutionNode.Builder nodePb =
        RemoteCall.ExecutionGraph.ExecutionNode.newBuilder();
    nodePb.setNodeId(node.getNodeId());
    nodePb.setParallelism(node.getParallelism());
    nodePb.setNodeType(Streaming.NodeType.valueOf(node.getNodeType().name()));
    nodePb.setLanguage(Streaming.Language.valueOf(node.getLanguage().name()));
    nodePb.setFunction(
        ByteString.copyFrom(serializeFunction(node.getStreamOperator().getFunction())));

    for (ExecutionTask task : node.getExecutionTasks()) {
      nodePb.addExecutionTasks(buildTaskPb(task));
    }
    for (ExecutionEdge inputEdge : node.getInputsEdges()) {
      nodePb.addInputEdges(buildEdgePb(inputEdge));
    }
    for (ExecutionEdge outputEdge : node.getOutputEdges()) {
      nodePb.addOutputEdges(buildEdgePb(outputEdge));
    }
    return nodePb.build();
  }

  /**
   * Converts one task; the worker actor handle is stored as the bytes returned by
   * {@code NativeRayActor.toBytes()}.
   */
  private RemoteCall.ExecutionGraph.ExecutionTask buildTaskPb(ExecutionTask task) {
    RemoteCall.ExecutionGraph.ExecutionTask.Builder taskPb =
        RemoteCall.ExecutionGraph.ExecutionTask.newBuilder();
    byte[] actorHandleBytes = ((NativeRayActor) task.getWorker()).toBytes();
    taskPb.setTaskId(task.getTaskId());
    taskPb.setTaskIndex(task.getTaskIndex());
    taskPb.setWorkerActor(ByteString.copyFrom(actorHandleBytes));
    return taskPb.build();
  }

  /**
   * Converts one edge: source node id, target node id and the serialized partition.
   */
  private RemoteCall.ExecutionGraph.ExecutionEdge buildEdgePb(ExecutionEdge edge) {
    byte[] partitionBytes = serializePartition(edge.getPartition());
    return RemoteCall.ExecutionGraph.ExecutionEdge.newBuilder()
        .setSrcNodeId(edge.getSrcNodeId())
        .setTargetNodeId(edge.getTargetNodeId())
        .setPartition(ByteString.copyFrom(partitionBytes))
        .build();
  }

  /**
   * Serializes a {@link PythonFunction} as the list
   * [function_bytes, module_name, class_name, function_name, function_interface].
   * Any other function is represented by an empty byte array.
   */
  private byte[] serializeFunction(Function function) {
    if (!(function instanceof PythonFunction)) {
      return new byte[0];
    }
    PythonFunction pythonFunction = (PythonFunction) function;
    return serializer.serialize(Arrays.asList(
        pythonFunction.getFunction(),
        pythonFunction.getModuleName(),
        pythonFunction.getClassName(),
        pythonFunction.getFunctionName(),
        pythonFunction.getFunctionInterface()));
  }

  /**
   * Serializes a {@link PythonPartition} as the list
   * [partition_bytes, module_name, class_name, function_name].
   * Any other partition is represented by an empty byte array.
   */
  private byte[] serializePartition(Partition partition) {
    if (!(partition instanceof PythonPartition)) {
      return new byte[0];
    }
    PythonPartition pyPartition = (PythonPartition) partition;
    return serializer.serialize(Arrays.asList(
        pyPartition.getPartition(),
        pyPartition.getModuleName(),
        pyPartition.getClassName(),
        pyPartition.getFunctionName()));
  }
}
@@ -0,0 +1,119 @@
package org.ray.streaming.runtime.python;
import com.google.common.io.BaseEncoding;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.msgpack.core.MessageBufferPacker;
import org.msgpack.core.MessagePack;
import org.msgpack.core.MessageUnpacker;
import org.msgpack.value.ArrayValue;
import org.msgpack.value.FloatValue;
import org.msgpack.value.IntegerValue;
import org.msgpack.value.MapValue;
import org.msgpack.value.Value;
/**
 * Converts plain Java values to MessagePack bytes and back.
 *
 * <p>Supported values for serialization: {@code null}, {@link Boolean}, {@link Byte},
 * {@link Short}, {@link Integer}, {@link Long}, {@link java.math.BigInteger}, {@link Float},
 * {@link Double}, {@code byte[]}, {@link String}, and {@link Collection}s / {@link Map}s made
 * of supported values.
 *
 * <p>The round trip is not type-preserving for numbers: integers come back as the narrowest
 * of {@code Integer}, {@code Long} or {@code BigInteger} that holds the value, and every
 * floating point value comes back as {@code Double}.
 */
public class MsgPackSerializer {

  /**
   * Serializes {@code obj} to MessagePack bytes.
   *
   * @param obj the value to serialize, may be {@code null}
   * @return the MessagePack encoding of {@code obj}
   * @throws RuntimeException if {@code obj} (or a nested element) has an unsupported type or
   *     packing fails; the original exception is kept as the cause
   */
  public byte[] serialize(Object obj) {
    // The packer is Closeable; try-with-resources releases its buffers on every path.
    try (MessageBufferPacker packer = MessagePack.newDefaultBufferPacker()) {
      serialize(obj, packer);
      return packer.toByteArray();
    } catch (Exception e) {
      // Wrap once here rather than at every recursion level, so the cause chain stays short.
      throw new RuntimeException("Serialize error for object " + obj, e);
    }
  }

  /**
   * Recursively packs {@code obj} into {@code packer}.
   *
   * @throws UnsupportedOperationException if the type of {@code obj} is not supported
   */
  private void serialize(Object obj, MessageBufferPacker packer) throws java.io.IOException {
    if (obj == null) {
      packer.packNil();
      return;
    }
    Class<?> clz = obj.getClass();
    if (clz == Boolean.class) {
      packer.packBoolean((Boolean) obj);
    } else if (clz == Byte.class) {
      packer.packByte((Byte) obj);
    } else if (clz == Short.class) {
      packer.packShort((Short) obj);
    } else if (clz == Integer.class) {
      packer.packInt((Integer) obj);
    } else if (clz == Long.class) {
      packer.packLong((Long) obj);
    } else if (obj instanceof java.math.BigInteger) {
      // deserialize() may hand out BigInteger values, so accept them on the way in as well.
      packer.packBigInteger((java.math.BigInteger) obj);
    } else if (clz == Float.class) {
      packer.packFloat((Float) obj);
    } else if (clz == Double.class) {
      packer.packDouble((Double) obj);
    } else if (clz == byte[].class) {
      byte[] bytes = (byte[]) obj;
      packer.packBinaryHeader(bytes.length);
      packer.writePayload(bytes);
    } else if (clz == String.class) {
      packer.packString((String) obj);
    } else if (obj instanceof Collection) {
      Collection<?> collection = (Collection<?>) obj;
      packer.packArrayHeader(collection.size());
      for (Object element : collection) {
        serialize(element, packer);
      }
    } else if (obj instanceof Map) {
      Map<?, ?> map = (Map<?, ?>) obj;
      packer.packMapHeader(map.size());
      for (Map.Entry<?, ?> entry : map.entrySet()) {
        serialize(entry.getKey(), packer);
        serialize(entry.getValue(), packer);
      }
    } else {
      throw new UnsupportedOperationException("Unsupported type " + clz);
    }
  }

  /**
   * Deserializes the first MessagePack value found in {@code bytes}.
   *
   * @param bytes MessagePack encoded data
   * @return the decoded value: {@code null}, {@code Boolean}, {@code Integer}, {@code Long},
   *     {@code BigInteger}, {@code Double}, {@code String}, {@code byte[]}, {@code List} or
   *     {@code Map}
   * @throws RuntimeException if the data cannot be decoded; the message contains the input as
   *     lower-case hex and the original exception is kept as the cause
   */
  public Object deserialize(byte[] bytes) {
    try (MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(bytes)) {
      return deserialize(unpacker.unpackValue());
    } catch (Exception e) {
      // Guard against null so that building the message cannot mask the real failure.
      String hex = bytes == null ? "null" : BaseEncoding.base16().lowerCase().encode(bytes);
      throw new RuntimeException("Deserialize error: " + hex, e);
    }
  }

  /**
   * Recursively converts a MessagePack {@link Value} into plain Java objects.
   *
   * @throws UnsupportedOperationException for value types without a mapping here
   */
  private Object deserialize(Value value) {
    switch (value.getValueType()) {
      case NIL:
        return null;
      case BOOLEAN:
        return value.asBooleanValue().getBoolean();
      case INTEGER:
        // Pick the narrowest Java type that can hold the value.
        IntegerValue iv = value.asIntegerValue();
        if (iv.isInIntRange()) {
          return iv.toInt();
        } else if (iv.isInLongRange()) {
          return iv.toLong();
        } else {
          return iv.toBigInteger();
        }
      case FLOAT:
        FloatValue fv = value.asFloatValue();
        return fv.toDouble();
      case STRING:
        return value.asStringValue().asString();
      case BINARY:
        return value.asBinaryValue().asByteArray();
      case ARRAY:
        ArrayValue arrayValue = value.asArrayValue();
        List<Object> list = new ArrayList<>(arrayValue.size());
        for (Value elem : arrayValue) {
          list.add(deserialize(elem));
        }
        return list;
      case MAP:
        MapValue mapValue = value.asMapValue();
        Map<Object, Object> map = new HashMap<>();
        for (Map.Entry<Value, Value> entry : mapValue.entrySet()) {
          map.put(deserialize(entry.getKey()), deserialize(entry.getValue()));
        }
        return map;
      default:
        throw new UnsupportedOperationException("Unsupported type " + value.getValueType());
    }
  }
}
@@ -0,0 +1,152 @@
package org.ray.streaming.runtime.python;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.msgpack.core.Preconditions;
import org.ray.api.annotation.RayRemote;
import org.ray.streaming.api.context.StreamingContext;
import org.ray.streaming.python.PythonFunction;
import org.ray.streaming.python.PythonPartition;
import org.ray.streaming.python.stream.PythonStreamSource;
import org.ray.streaming.runtime.util.ReflectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Gateway for streaming python api.
 * All calls on DataStream in python will be mapped to DataStream call in java by this
 * PythonGateway using ray calls.
 * <p>
 * Java objects never cross the boundary: every object created through the gateway is kept
 * in a reference map and the caller receives a serialized reference id string instead.
 * When such an id is later passed back as a parameter it is replaced by the object.
 * <p>
 * Note: this class needs to be in sync with `GatewayClient` in
 * `streaming/python/runtime/gateway_client.py`
 */
@SuppressWarnings("unchecked")
@RayRemote
public class PythonGateway {
  private static final Logger LOG = LoggerFactory.getLogger(PythonGateway.class);
  private static final String REFERENCE_ID_PREFIX = "__gateway_reference_id__";

  private MsgPackSerializer serializer;
  private Map<String, Object> referenceMap;
  private StreamingContext streamingContext;
  // Source of unique reference ids. Plain field: assumes gateway methods are invoked one at
  // a time, as the unsynchronized referenceMap already requires -- TODO confirm.
  private long nextReferenceId;

  public PythonGateway() {
    serializer = new MsgPackSerializer();
    referenceMap = new HashMap<>();
    LOG.info("PythonGateway created");
  }

  /**
   * Creates the streaming context used by all later calls.
   *
   * @return the serialized reference id of the new context
   */
  public byte[] createStreamingContext() {
    streamingContext = StreamingContext.buildContext();
    LOG.info("StreamingContext created");
    return registerReference(streamingContext);
  }

  public StreamingContext getStreamingContext() {
    return streamingContext;
  }

  /**
   * Applies a configuration to the streaming context.
   *
   * @param confBytes a serialized map of configuration keys to values
   * @return a one-byte array signalling success
   */
  public byte[] withConfig(byte[] confBytes) {
    Preconditions.checkNotNull(streamingContext);
    try {
      Map<String, String> config = (Map<String, String>) serializer.deserialize(confBytes);
      LOG.info("Set config {}", config);
      streamingContext.withConfig(config);
      // We can't use `return void`, that will make `ray.get()` hang forever.
      // We can't using `return new byte[0]`, that will make `ray::CoreWorker::ExecuteTask` crash.
      // So we `return new byte[1]` for method execution success.
      // Same for other methods in this class which return new byte[1].
      return new byte[1];
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Creates a python stream source on the streaming context.
   *
   * @param pySourceFunc the serialized python source function
   * @return the serialized reference id of the new source
   */
  public byte[] createPythonStreamSource(byte[] pySourceFunc) {
    Preconditions.checkNotNull(streamingContext);
    try {
      PythonStreamSource pythonStreamSource = PythonStreamSource.from(
          streamingContext, PythonFunction.fromFunction(pySourceFunc));
      return registerReference(pythonStreamSource);
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Executes the job built on the streaming context.
   *
   * @param jobNameBytes the serialized job name
   * @return a one-byte array signalling success
   */
  public byte[] execute(byte[] jobNameBytes) {
    // Same precondition as the other context-dependent methods.
    Preconditions.checkNotNull(streamingContext);
    LOG.info("Starting executing");
    streamingContext.execute((String) serializer.deserialize(jobNameBytes));
    // see `withConfig` method.
    return new byte[1];
  }

  /**
   * Wraps a serialized python function.
   *
   * @return the serialized reference id of the created {@link PythonFunction}
   */
  public byte[] createPyFunc(byte[] pyFunc) {
    PythonFunction function = PythonFunction.fromFunction(pyFunc);
    return registerReference(function);
  }

  /**
   * Wraps a serialized python partition.
   *
   * @return the serialized reference id of the created {@link PythonPartition}
   */
  public byte[] createPyPartition(byte[] pyPartition) {
    PythonPartition partition = new PythonPartition(pyPartition);
    return registerReference(partition);
  }

  /**
   * Invokes a static method.
   *
   * @param paramsBytes a serialized list: class name, method name, then the arguments;
   *     reference ids among the elements are replaced by the referenced objects
   * @return the serialized reference id of the result
   */
  public byte[] callFunction(byte[] paramsBytes) {
    try {
      List<Object> params = (List<Object>) serializer.deserialize(paramsBytes);
      params = processReferenceParameters(params);
      LOG.info("callFunction params {}", params);
      String className = (String) params.get(0);
      String funcName = (String) params.get(1);
      Class<?> clz = Class.forName(className, true, this.getClass().getClassLoader());
      Method method = ReflectionUtils.findMethod(clz, funcName);
      Object result = method.invoke(null, params.subList(2, params.size()).toArray());
      return registerReference(result);
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Invokes an instance method.
   *
   * @param paramsBytes a serialized list: target object reference id, method name, then the
   *     arguments; reference ids among the elements are replaced by the referenced objects
   * @return the serialized reference id of the result
   */
  public byte[] callMethod(byte[] paramsBytes) {
    try {
      List<Object> params = (List<Object>) serializer.deserialize(paramsBytes);
      params = processReferenceParameters(params);
      LOG.info("callMethod params {}", params);
      Object obj = params.get(0);
      String methodName = (String) params.get(1);
      Method method = ReflectionUtils.findMethod(obj.getClass(), methodName);
      Object result = method.invoke(obj, params.subList(2, params.size()).toArray());
      return registerReference(result);
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

  private List<Object> processReferenceParameters(List<Object> params) {
    return params.stream().map(this::processReferenceParameter)
        .collect(Collectors.toList());
  }

  /**
   * Replaces a known reference id by the object it refers to; anything else is returned
   * unchanged.
   */
  private Object processReferenceParameter(Object o) {
    if (o instanceof String) {
      Object value = referenceMap.get(o);
      if (value != null) {
        return value;
      }
    }
    return o;
  }

  /**
   * Stores {@code o} under a fresh reference id and returns the serialized id.
   *
   * <p>Ids come from a counter rather than {@code System.identityHashCode}: identity hash
   * codes are not guaranteed to be distinct for distinct objects, so two objects could end
   * up under the same key and later calls would be dispatched to the wrong one.
   */
  private byte[] registerReference(Object o) {
    String referenceId = REFERENCE_ID_PREFIX + (nextReferenceId++);
    referenceMap.put(referenceId, o);
    return serializer.serialize(referenceId);
  }
}
@@ -6,12 +6,14 @@ import java.util.Map;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
import org.ray.streaming.api.Language;
import org.ray.streaming.jobgraph.JobGraph;
import org.ray.streaming.jobgraph.JobVertex;
import org.ray.streaming.runtime.cluster.ResourceManager;
import org.ray.streaming.runtime.core.graph.ExecutionGraph;
import org.ray.streaming.runtime.core.graph.ExecutionNode;
import org.ray.streaming.runtime.core.graph.ExecutionTask;
import org.ray.streaming.runtime.generated.RemoteCall;
import org.ray.streaming.runtime.python.GraphPbBuilder;
import org.ray.streaming.runtime.worker.JobWorker;
import org.ray.streaming.runtime.worker.context.WorkerContext;
import org.ray.streaming.schedule.JobScheduler;
@@ -23,43 +25,70 @@ import org.ray.streaming.schedule.JobScheduler;
public class JobSchedulerImpl implements JobScheduler {
private JobGraph jobGraph;
private Map<String, String> jobConfig;
private ResourceManager resourceManager;
private TaskAssigner taskAssigner;
public JobSchedulerImpl() {
this.resourceManager = new ResourceManager();
this.taskAssigner = new TaskAssignerImpl();
}
/**
* Schedule physical plan to execution graph, and call streaming worker to init and run.
*/
@SuppressWarnings("unchecked")
@Override
public void schedule(JobGraph jobGraph, Map<String, String> jobConfig) {
this.jobConfig = jobConfig;
this.jobGraph = jobGraph;
System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
Ray.init();
List<RayActor<JobWorker>> workers = this.resourceManager.createWorkers(getPlanWorker());
ExecutionGraph executionGraph = this.taskAssigner.assign(this.jobGraph, workers);
if (Ray.internal() == null) {
System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
Ray.init();
}
ExecutionGraph executionGraph = this.taskAssigner.assign(this.jobGraph);
List<ExecutionNode> executionNodes = executionGraph.getExecutionNodeList();
List<RayObject<Boolean>> waits = new ArrayList<>();
// The protobuf graph is needed as soon as ANY node runs in Python, because every Python
// task's worker context embeds it. Requiring ALL nodes to be Python would leave it null
// for graphs that mix Java and Python nodes.
boolean hasPythonNode = executionNodes.stream()
    .anyMatch(node -> node.getLanguage() == Language.PYTHON);
RemoteCall.ExecutionGraph executionGraphPb = null;
if (hasPythonNode) {
executionGraphPb = new GraphPbBuilder().buildExecutionGraphPb(executionGraph);
}
List<RayObject<Object>> waits = new ArrayList<>();
for (ExecutionNode executionNode : executionNodes) {
List<ExecutionTask> executionTasks = executionNode.getExecutionTasks();
for (ExecutionTask executionTask : executionTasks) {
int taskId = executionTask.getTaskId();
RayActor<JobWorker> streamWorker = executionTask.getWorker();
waits.add(Ray.call(JobWorker::init, streamWorker,
new WorkerContext(taskId, executionGraph, jobConfig)));
RayActor worker = executionTask.getWorker();
switch (executionNode.getLanguage()) {
case JAVA:
RayActor<JobWorker> jobWorker = (RayActor<JobWorker>) worker;
waits.add(Ray.call(JobWorker::init, jobWorker,
new WorkerContext(taskId, executionGraph, jobConfig)));
break;
case PYTHON:
byte[] workerContextBytes = buildPythonWorkerContext(
taskId, executionGraphPb, jobConfig);
waits.add(Ray.callPy((RayPyActor) worker,
"init", workerContextBytes));
break;
default:
throw new UnsupportedOperationException(
"Unsupported language " + executionNode.getLanguage());
}
}
}
Ray.wait(waits);
}
private int getPlanWorker() {
List<JobVertex> jobVertexList = jobGraph.getJobVertexList();
return jobVertexList.stream().map(JobVertex::getParallelism).reduce(0, Integer::sum);
private byte[] buildPythonWorkerContext(
int taskId,
RemoteCall.ExecutionGraph executionGraphPb,
Map<String, String> jobConfig) {
return RemoteCall.WorkerContext.newBuilder()
.setTaskId(taskId)
.putAllConf(jobConfig)
.setGraph(executionGraphPb)
.build()
.toByteArray();
}
}
@@ -1,11 +1,8 @@
package org.ray.streaming.runtime.schedule;
import java.io.Serializable;
import java.util.List;
import org.ray.api.RayActor;
import org.ray.streaming.jobgraph.JobGraph;
import org.ray.streaming.runtime.core.graph.ExecutionGraph;
import org.ray.streaming.runtime.worker.JobWorker;
/**
* Interface of the task assigning strategy.
@@ -15,6 +12,6 @@ public interface TaskAssigner extends Serializable {
/**
* Assign logical plan to physical execution graph.
*/
ExecutionGraph assign(JobGraph jobGraph, List<RayActor<JobWorker>> workers);
ExecutionGraph assign(JobGraph jobGraph);
}
@@ -4,7 +4,7 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.streaming.jobgraph.JobEdge;
import org.ray.streaming.jobgraph.JobGraph;
@@ -20,12 +20,11 @@ public class TaskAssignerImpl implements TaskAssigner {
/**
* Assign an optimized logical plan to execution graph.
*
* @param jobGraph The logical plan.
* @param workers The worker actors.
* @param jobGraph The logical plan.
* @return The physical execution graph.
*/
@Override
public ExecutionGraph assign(JobGraph jobGraph, List<RayActor<JobWorker>> workers) {
public ExecutionGraph assign(JobGraph jobGraph) {
List<JobVertex> jobVertices = jobGraph.getJobVertexList();
List<JobEdge> jobEdges = jobGraph.getJobEdgeList();
@@ -37,7 +36,7 @@ public class TaskAssignerImpl implements TaskAssigner {
executionNode.setNodeType(jobVertex.getVertexType());
List<ExecutionTask> vertexTasks = new ArrayList<>();
for (int taskIndex = 0; taskIndex < jobVertex.getParallelism(); taskIndex++) {
vertexTasks.add(new ExecutionTask(taskId, taskIndex, workers.get(taskId)));
vertexTasks.add(new ExecutionTask(taskId, taskIndex, createWorker(jobVertex)));
taskId++;
}
executionNode.setExecutionTasks(vertexTasks);
@@ -51,12 +50,25 @@ public class TaskAssignerImpl implements TaskAssigner {
ExecutionEdge executionEdge = new ExecutionEdge(srcNodeId, targetNodeId,
jobEdge.getPartition());
idToExecutionNode.get(srcNodeId).addExecutionEdge(executionEdge);
idToExecutionNode.get(srcNodeId).addOutputEdge(executionEdge);
idToExecutionNode.get(targetNodeId).addInputEdge(executionEdge);
}
List<ExecutionNode> executionNodes = idToExecutionNode.values().stream()
.collect(Collectors.toList());
List<ExecutionNode> executionNodes = new ArrayList<>(idToExecutionNode.values());
return new ExecutionGraph(executionNodes);
}
/**
 * Creates the worker actor for one task of the given vertex, choosing the actor
 * implementation by the vertex language.
 *
 * @param jobVertex the vertex the worker will run
 * @return the handle of the newly created actor
 * @throws UnsupportedOperationException if the vertex language is neither JAVA nor PYTHON
 */
private RayActor createWorker(JobVertex jobVertex) {
  final RayActor workerHandle;
  switch (jobVertex.getLanguage()) {
    case JAVA:
      workerHandle = Ray.createActor(JobWorker::new);
      break;
    case PYTHON:
      workerHandle = Ray.createPyActor("ray.streaming.runtime.worker", "JobWorker");
      break;
    default:
      throw new UnsupportedOperationException(
          "Unsupported language " + jobVertex.getLanguage());
  }
  return workerHandle;
}
}
@@ -0,0 +1,86 @@
package org.ray.streaming.runtime.util;
import com.google.common.base.Preconditions;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
/**
 * Reflection helpers for looking up methods by name only.
 */
@SuppressWarnings("UnstableApiUsage")
public class ReflectionUtils {

  /**
   * Finds the single method named {@code methodName} on {@code cls}, its superclasses and
   * interfaces.
   *
   * @param cls the class to search
   * @param methodName the method name to look for
   * @return the only method with that name
   * @throws IllegalArgumentException if no method or more than one method (i.e. overloads
   *     with different parameter lists) carries that name
   */
  public static Method findMethod(Class<?> cls, String methodName) {
    List<Method> methods = findMethods(cls, methodName);
    // Name the class, method and candidates so that a missing or overloaded method can be
    // diagnosed from the exception alone.
    Preconditions.checkArgument(methods.size() == 1,
        "Expected exactly one method named %s in %s but found %s: %s",
        methodName, cls, methods.size(), methods);
    return methods.get(0);
  }

  /**
   * Finds all methods named {@code methodName} declared by {@code cls}, its superclasses,
   * its interfaces and {@code Object}. At most one method is returned per parameter list;
   * for covariant return types, the most specific method is returned.
   *
   * @param cls the class to search
   * @param methodName the method name to look for
   * @return all methods named by {@code methodName}, one per distinct parameter list
   */
  public static List<Method> findMethods(Class<?> cls, String methodName) {
    Preconditions.checkNotNull(cls, "cls");
    // Search order: the class and its superclasses, then all interfaces, then Object
    // (which is missing from the superclass chain when cls is an interface).
    List<Class<?>> classes = new ArrayList<>();
    Class<?> clazz = cls;
    while (clazz != null) {
      classes.add(clazz);
      clazz = clazz.getSuperclass();
    }
    classes.addAll(getAllInterfaces(cls));
    if (!classes.contains(Object.class)) {
      classes.add(Object.class);
    }

    // Keyed by parameter types so that each signature is kept only once.
    LinkedHashMap<List<Class<?>>, Method> methods = new LinkedHashMap<>();
    for (Class<?> superClass : classes) {
      for (Method m : superClass.getDeclaredMethods()) {
        if (m.getName().equals(methodName)) {
          List<Class<?>> params = Arrays.asList(m.getParameterTypes());
          Method method = methods.get(params);
          if (method == null) {
            methods.put(params, m);
          } else {
            // for covariant return type, use the most specific method
            if (method.getReturnType().isAssignableFrom(m.getReturnType())) {
              methods.put(params, m);
            }
          }
        }
      }
    }
    return new ArrayList<>(methods.values());
  }

  /**
   * <p>Gets a <code>List</code> of all interfaces implemented by the given
   * class and its superclasses.</p>
   * <p>The order is determined by looking through each interface in turn as
   * declared in the source file and following its hierarchy up.</p>
   *
   * @param cls the class to inspect, may be {@code null}
   * @return the interfaces without duplicates, or {@code null} if {@code cls} is {@code null}
   */
  public static List<Class<?>> getAllInterfaces(Class<?> cls) {
    if (cls == null) {
      return null;
    }

    LinkedHashSet<Class<?>> interfacesFound = new LinkedHashSet<>();
    getAllInterfaces(cls, interfacesFound);
    return new ArrayList<>(interfacesFound);
  }

  /**
   * Collects the interfaces of {@code cls} and of all its superclasses into
   * {@code interfacesFound}, descending into each newly found interface.
   */
  private static void getAllInterfaces(Class<?> cls, LinkedHashSet<Class<?>> interfacesFound) {
    while (cls != null) {
      Class<?>[] interfaces = cls.getInterfaces();
      for (Class<?> anInterface : interfaces) {
        if (!interfacesFound.contains(anInterface)) {
          interfacesFound.add(anInterface);
          getAllInterfaces(anInterface, interfacesFound);
        }
      }
      cls = cls.getSuperclass();
    }
  }
}
@@ -2,6 +2,7 @@ package org.ray.streaming.runtime.worker;
import java.io.Serializable;
import java.util.Map;
import org.ray.api.Ray;
import org.ray.api.annotation.RayRemote;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
@@ -65,7 +65,7 @@ public abstract class StreamTask implements Runnable {
List<Collector> collectors = new ArrayList<>();
for (ExecutionEdge edge : outputEdges) {
Map<String, ActorId> outputActorIds = new HashMap<>();
Map<Integer, RayActor<JobWorker>> taskId2Worker = executionGraph
Map<Integer, RayActor> taskId2Worker = executionGraph
.getTaskId2WorkerByNodeId(edge.getTargetNodeId());
taskId2Worker.forEach((targetTaskId, targetActor) -> {
String queueName = ChannelID.genIdStr(taskId, targetTaskId, executionGraph.getBuildTime());
@@ -91,7 +91,7 @@ public abstract class StreamTask implements Runnable {
List<ExecutionEdge> inputEdges = executionNode.getInputsEdges();
Map<String, ActorId> inputActorIds = new HashMap<>();
for (ExecutionEdge edge : inputEdges) {
Map<Integer, RayActor<JobWorker>> taskId2Worker = executionGraph
Map<Integer, RayActor> taskId2Worker = executionGraph
.getTaskId2WorkerByNodeId(edge.getSrcNodeId());
taskId2Worker.forEach((srcTaskId, srcActor) -> {
String queueName = ChannelID.genIdStr(srcTaskId, taskId, executionGraph.getBuildTime());
@@ -0,0 +1,39 @@
package org.ray.streaming.runtime.python;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.testng.annotations.Test;
/**
 * Round-trip tests for {@link MsgPackSerializer}.
 */
public class MsgPackSerializerTest {

  /**
   * A nested map/list structure and a raw byte array must survive serialize + deserialize.
   */
  @Test
  public void testSerialize() {
    MsgPackSerializer serializer = new MsgPackSerializer();

    List<Object> nested = new ArrayList<>();
    nested.add(null);
    nested.add(true);
    nested.add(1);
    nested.add(1.0d);
    nested.add("str");

    Map<String, Object> original = new HashMap<>();
    original.put("k1", "value1");
    original.put("k2", 2);
    original.put("k3", nested);

    byte[] encoded = serializer.serialize(original);
    Object roundTripped = serializer.deserialize(encoded);
    assertEquals(roundTripped, original);

    byte[] binary = {1, 2, 3, 4};
    byte[] decodedBinary = (byte[]) serializer.deserialize(serializer.serialize(binary));
    assertTrue(Arrays.equals(binary, decodedBinary));
  }
}
@@ -0,0 +1,48 @@
package org.ray.streaming.runtime.python;
import static org.testng.Assert.assertEquals;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.ray.streaming.api.stream.StreamSink;
import org.ray.streaming.jobgraph.JobGraph;
import org.ray.streaming.jobgraph.JobGraphBuilder;
import org.testng.annotations.Test;
/**
 * Drives {@link PythonGateway} the way a client would: source, map, partitionBy, sink.
 */
public class PythonGatewayTest {

  /**
   * Builds a source -> map -> partitionBy -> sink pipeline through the gateway and checks
   * that exactly one sink was registered on the streaming context.
   */
  @Test
  public void testPythonGateway() {
    MsgPackSerializer serializer = new MsgPackSerializer();
    PythonGateway gateway = new PythonGateway();
    gateway.createStreamingContext();

    Map<String, String> conf = new HashMap<>();
    conf.put("k1", "v1");
    gateway.withConfig(serializer.serialize(conf));

    // Every gateway call answers with a serialized reference to the created object.
    byte[] sourceBytes = new byte[0];
    Object sourceRef = serializer.deserialize(gateway.createPythonStreamSource(sourceBytes));

    byte[] mapFuncBytes = new byte[0];
    Object mapFuncRef = serializer.deserialize(gateway.createPyFunc(mapFuncBytes));
    byte[] mapCall = serializer.serialize(Arrays.asList(sourceRef, "map", mapFuncRef));
    Object mapStreamRef = serializer.deserialize(gateway.callMethod(mapCall));

    byte[] partitionBytes = new byte[0];
    Object partitionRef = serializer.deserialize(gateway.createPyPartition(partitionBytes));
    byte[] partitionCall =
        serializer.serialize(Arrays.asList(mapStreamRef, "partitionBy", partitionRef));
    Object partitionedStreamRef = serializer.deserialize(gateway.callMethod(partitionCall));

    byte[] sinkFuncBytes = new byte[0];
    Object sinkFuncRef = serializer.deserialize(gateway.createPyFunc(sinkFuncBytes));
    byte[] sinkCall =
        serializer.serialize(Arrays.asList(partitionedStreamRef, "sink", sinkFuncRef));
    gateway.callMethod(sinkCall);

    List<StreamSink> streamSinks = gateway.getStreamingContext().getStreamSinks();
    assertEquals(streamSinks.size(), 1);

    JobGraph jobGraph = new JobGraphBuilder(streamSinks, "py_job").build();
    jobGraph.printJobGraph();
  }
}
@@ -1,26 +1,20 @@
package org.ray.streaming.runtime.schedule;
import java.util.ArrayList;
import java.util.List;
import com.google.common.collect.Lists;
import org.ray.api.RayActor;
import org.ray.api.id.ActorId;
import org.ray.api.id.ObjectId;
import org.ray.runtime.actor.LocalModeRayActor;
import java.util.List;
import org.ray.api.Ray;
import org.ray.streaming.api.context.StreamingContext;
import org.ray.streaming.api.partition.impl.RoundRobinPartition;
import org.ray.streaming.api.stream.DataStream;
import org.ray.streaming.api.stream.DataStreamSink;
import org.ray.streaming.api.stream.DataStreamSource;
import org.ray.streaming.jobgraph.JobGraph;
import org.ray.streaming.jobgraph.JobGraphBuilder;
import org.ray.streaming.runtime.BaseUnitTest;
import org.ray.streaming.runtime.core.graph.ExecutionEdge;
import org.ray.streaming.runtime.core.graph.ExecutionGraph;
import org.ray.streaming.runtime.core.graph.ExecutionNode;
import org.ray.streaming.runtime.core.graph.ExecutionNode.NodeType;
import org.ray.streaming.runtime.worker.JobWorker;
import org.ray.streaming.jobgraph.JobGraph;
import org.ray.streaming.jobgraph.JobGraphBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
@@ -32,15 +26,11 @@ public class TaskAssignerImplTest extends BaseUnitTest {
@Test
public void testTaskAssignImpl() {
Ray.init();
JobGraph jobGraph = buildDataSyncPlan();
List<RayActor<JobWorker>> workers = new ArrayList<>();
for(int i = 0; i < jobGraph.getJobVertexList().size(); i++) {
workers.add(new LocalModeRayActor(ActorId.fromRandom(), ObjectId.fromRandom()));
}
TaskAssigner taskAssigner = new TaskAssignerImpl();
ExecutionGraph executionGraph = taskAssigner.assign(jobGraph, workers);
ExecutionGraph executionGraph = taskAssigner.assign(jobGraph);
List<ExecutionNode> executionNodeList = executionGraph.getExecutionNodeList();
@@ -61,16 +51,17 @@ public class TaskAssignerImplTest extends BaseUnitTest {
Assert.assertEquals(sinkNode.getNodeType(), NodeType.SINK);
Assert.assertEquals(sinkNode.getExecutionTasks().size(), 1);
Assert.assertEquals(sinkNode.getOutputEdges().size(), 0);
Ray.shutdown();
}
public JobGraph buildDataSyncPlan() {
StreamingContext streamingContext = StreamingContext.buildContext();
DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext,
Lists.newArrayList("a", "b", "c"));
DataStreamSink streamSink = dataStream.sink(x -> LOGGER.info(x));
DataStreamSink streamSink = dataStream.sink(LOGGER::info);
JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink));
JobGraph jobGraph = jobGraphBuilder.build();
return jobGraph;
return jobGraphBuilder.build();
}
}
@@ -0,0 +1,38 @@
package org.ray.streaming.runtime.util;
import static org.testng.Assert.assertEquals;
import java.io.Serializable;
import java.util.Collections;
import org.testng.annotations.Test;
/**
 * Tests for {@link ReflectionUtils}.
 */
public class ReflectionUtilsTest {

  /** Fixture: one uniquely named method, one overloaded method, one interface. */
  static class Foo implements Serializable {

    public void f1() {
    }

    public void f2() {
    }

    public void f2(boolean a1) {
    }
  }

  /** A uniquely named method is found and is the one declared on the class. */
  @Test
  public void testFindMethod() throws NoSuchMethodException {
    // TestNG's assertEquals takes (actual, expected); the looked-up method is the actual.
    assertEquals(ReflectionUtils.findMethod(Foo.class, "f1"),
        Foo.class.getDeclaredMethod("f1"));
  }

  /** Both overloads of an overloaded name are returned. */
  @Test
  public void testFindMethods() {
    assertEquals(ReflectionUtils.findMethods(Foo.class, "f2").size(), 2);
  }

  /** The only interface of the fixture class is reported. */
  @Test
  public void testGetAllInterfaces() {
    assertEquals(ReflectionUtils.getAllInterfaces(Foo.class),
        Collections.singletonList(Serializable.class));
  }
}