mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 21:05:07 +08:00
[Streaming] Streaming Python API (#6755)
This commit is contained in:
-23
@@ -1,23 +0,0 @@
|
||||
package org.ray.streaming.runtime.cluster;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.streaming.runtime.worker.JobWorker;
|
||||
|
||||
/**
|
||||
* Resource-Manager is used to do the management of resources
|
||||
*/
|
||||
public class ResourceManager {
|
||||
|
||||
public List<RayActor<JobWorker>> createWorkers(int workerNum) {
|
||||
List<RayActor<JobWorker>> workers = new ArrayList<>();
|
||||
for (int i = 0; i < workerNum; i++) {
|
||||
RayActor<JobWorker> worker = Ray.createActor(JobWorker::new);
|
||||
workers.add(worker);
|
||||
}
|
||||
return workers;
|
||||
}
|
||||
|
||||
}
|
||||
+8
-9
@@ -7,7 +7,6 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.streaming.runtime.worker.JobWorker;
|
||||
|
||||
/**
|
||||
* Physical execution graph.
|
||||
@@ -19,19 +18,19 @@ import org.ray.streaming.runtime.worker.JobWorker;
|
||||
public class ExecutionGraph implements Serializable {
|
||||
private long buildTime;
|
||||
private List<ExecutionNode> executionNodeList;
|
||||
private List<RayActor<JobWorker>> sourceWorkers = new ArrayList<>();
|
||||
private List<RayActor<JobWorker>> sinkWorkers = new ArrayList<>();
|
||||
private List<RayActor> sourceWorkers = new ArrayList<>();
|
||||
private List<RayActor> sinkWorkers = new ArrayList<>();
|
||||
|
||||
public ExecutionGraph(List<ExecutionNode> executionNodes) {
|
||||
this.executionNodeList = executionNodes;
|
||||
for (ExecutionNode executionNode : executionNodeList) {
|
||||
if (executionNode.getNodeType() == ExecutionNode.NodeType.SOURCE) {
|
||||
List<RayActor<JobWorker>> actors = executionNode.getExecutionTasks().stream()
|
||||
List<RayActor> actors = executionNode.getExecutionTasks().stream()
|
||||
.map(ExecutionTask::getWorker).collect(Collectors.toList());
|
||||
sourceWorkers.addAll(actors);
|
||||
}
|
||||
if (executionNode.getNodeType() == ExecutionNode.NodeType.SINK) {
|
||||
List<RayActor<JobWorker>> actors = executionNode.getExecutionTasks().stream()
|
||||
List<RayActor> actors = executionNode.getExecutionTasks().stream()
|
||||
.map(ExecutionTask::getWorker).collect(Collectors.toList());
|
||||
sinkWorkers.addAll(actors);
|
||||
}
|
||||
@@ -39,11 +38,11 @@ public class ExecutionGraph implements Serializable {
|
||||
buildTime = System.currentTimeMillis();
|
||||
}
|
||||
|
||||
public List<RayActor<JobWorker>> getSourceWorkers() {
|
||||
public List<RayActor> getSourceWorkers() {
|
||||
return sourceWorkers;
|
||||
}
|
||||
|
||||
public List<RayActor<JobWorker>> getSinkWorkers() {
|
||||
public List<RayActor> getSinkWorkers() {
|
||||
return sinkWorkers;
|
||||
}
|
||||
|
||||
@@ -82,10 +81,10 @@ public class ExecutionGraph implements Serializable {
|
||||
throw new RuntimeException("Task " + taskId + " does not exist!");
|
||||
}
|
||||
|
||||
public Map<Integer, RayActor<JobWorker>> getTaskId2WorkerByNodeId(int nodeId) {
|
||||
public Map<Integer, RayActor> getTaskId2WorkerByNodeId(int nodeId) {
|
||||
for (ExecutionNode executionNode : executionNodeList) {
|
||||
if (executionNode.getNodeId() == nodeId) {
|
||||
Map<Integer, RayActor<JobWorker>> taskId2Worker = new HashMap<>();
|
||||
Map<Integer, RayActor> taskId2Worker = new HashMap<>();
|
||||
for (ExecutionTask executionTask : executionNode.getExecutionTasks()) {
|
||||
taskId2Worker.put(executionTask.getTaskId(), executionTask.getWorker());
|
||||
}
|
||||
|
||||
+8
-4
@@ -3,6 +3,7 @@ package org.ray.streaming.runtime.core.graph;
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.ray.streaming.api.Language;
|
||||
import org.ray.streaming.jobgraph.VertexType;
|
||||
import org.ray.streaming.operator.StreamOperator;
|
||||
|
||||
@@ -10,7 +11,6 @@ import org.ray.streaming.operator.StreamOperator;
|
||||
* A node in the physical execution graph.
|
||||
*/
|
||||
public class ExecutionNode implements Serializable {
|
||||
|
||||
private int nodeId;
|
||||
private int parallelism;
|
||||
private NodeType nodeType;
|
||||
@@ -59,7 +59,7 @@ public class ExecutionNode implements Serializable {
|
||||
this.outputEdges = outputEdges;
|
||||
}
|
||||
|
||||
public void addExecutionEdge(ExecutionEdge executionEdge) {
|
||||
public void addOutputEdge(ExecutionEdge executionEdge) {
|
||||
this.outputEdges.add(executionEdge);
|
||||
}
|
||||
|
||||
@@ -79,6 +79,10 @@ public class ExecutionNode implements Serializable {
|
||||
this.streamOperator = streamOperator;
|
||||
}
|
||||
|
||||
public Language getLanguage() {
|
||||
return streamOperator.getLanguage();
|
||||
}
|
||||
|
||||
public NodeType getNodeType() {
|
||||
return nodeType;
|
||||
}
|
||||
@@ -92,7 +96,7 @@ public class ExecutionNode implements Serializable {
|
||||
this.nodeType = NodeType.SINK;
|
||||
break;
|
||||
default:
|
||||
this.nodeType = NodeType.PROCESS;
|
||||
this.nodeType = NodeType.TRANSFORM;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,7 +113,7 @@ public class ExecutionNode implements Serializable {
|
||||
|
||||
public enum NodeType {
|
||||
SOURCE,
|
||||
PROCESS,
|
||||
TRANSFORM,
|
||||
SINK,
|
||||
}
|
||||
}
|
||||
|
||||
+4
-5
@@ -2,7 +2,6 @@ package org.ray.streaming.runtime.core.graph;
|
||||
|
||||
import java.io.Serializable;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.streaming.runtime.worker.JobWorker;
|
||||
|
||||
/**
|
||||
* ExecutionTask is minimal execution unit.
|
||||
@@ -12,9 +11,9 @@ import org.ray.streaming.runtime.worker.JobWorker;
|
||||
public class ExecutionTask implements Serializable {
|
||||
private int taskId;
|
||||
private int taskIndex;
|
||||
private RayActor<JobWorker> worker;
|
||||
private RayActor worker;
|
||||
|
||||
public ExecutionTask(int taskId, int taskIndex, RayActor<JobWorker> worker) {
|
||||
public ExecutionTask(int taskId, int taskIndex, RayActor worker) {
|
||||
this.taskId = taskId;
|
||||
this.taskIndex = taskIndex;
|
||||
this.worker = worker;
|
||||
@@ -36,11 +35,11 @@ public class ExecutionTask implements Serializable {
|
||||
this.taskIndex = taskIndex;
|
||||
}
|
||||
|
||||
public RayActor<JobWorker> getWorker() {
|
||||
public RayActor getWorker() {
|
||||
return worker;
|
||||
}
|
||||
|
||||
public void setWorker(RayActor<JobWorker> worker) {
|
||||
public void setWorker(RayActor worker) {
|
||||
this.worker = worker;
|
||||
}
|
||||
}
|
||||
|
||||
+101
@@ -0,0 +1,101 @@
|
||||
package org.ray.streaming.runtime.python;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import java.util.Arrays;
|
||||
import org.ray.runtime.actor.NativeRayActor;
|
||||
import org.ray.streaming.api.function.Function;
|
||||
import org.ray.streaming.api.partition.Partition;
|
||||
import org.ray.streaming.python.PythonFunction;
|
||||
import org.ray.streaming.python.PythonPartition;
|
||||
import org.ray.streaming.runtime.core.graph.ExecutionEdge;
|
||||
import org.ray.streaming.runtime.core.graph.ExecutionGraph;
|
||||
import org.ray.streaming.runtime.core.graph.ExecutionNode;
|
||||
import org.ray.streaming.runtime.core.graph.ExecutionTask;
|
||||
import org.ray.streaming.runtime.generated.RemoteCall;
|
||||
import org.ray.streaming.runtime.generated.Streaming;
|
||||
|
||||
public class GraphPbBuilder {
|
||||
|
||||
private MsgPackSerializer serializer = new MsgPackSerializer();
|
||||
|
||||
/**
|
||||
* For simple scenario, a single ExecutionNode is enough. But some cases may need
|
||||
* sub-graph information, so we serialize entire graph.
|
||||
*/
|
||||
public RemoteCall.ExecutionGraph buildExecutionGraphPb(ExecutionGraph graph) {
|
||||
RemoteCall.ExecutionGraph.Builder builder = RemoteCall.ExecutionGraph.newBuilder();
|
||||
builder.setBuildTime(graph.getBuildTime());
|
||||
for (ExecutionNode node : graph.getExecutionNodeList()) {
|
||||
RemoteCall.ExecutionGraph.ExecutionNode.Builder nodeBuilder =
|
||||
RemoteCall.ExecutionGraph.ExecutionNode.newBuilder();
|
||||
nodeBuilder.setNodeId(node.getNodeId());
|
||||
nodeBuilder.setParallelism(node.getParallelism());
|
||||
nodeBuilder.setNodeType(
|
||||
Streaming.NodeType.valueOf(node.getNodeType().name()));
|
||||
nodeBuilder.setLanguage(Streaming.Language.valueOf(node.getLanguage().name()));
|
||||
byte[] functionBytes = serializeFunction(node.getStreamOperator().getFunction());
|
||||
nodeBuilder.setFunction(ByteString.copyFrom(functionBytes));
|
||||
|
||||
// build tasks
|
||||
for (ExecutionTask task : node.getExecutionTasks()) {
|
||||
RemoteCall.ExecutionGraph.ExecutionTask.Builder taskBuilder =
|
||||
RemoteCall.ExecutionGraph.ExecutionTask.newBuilder();
|
||||
byte[] serializedActorHandle = ((NativeRayActor) task.getWorker()).toBytes();
|
||||
taskBuilder
|
||||
.setTaskId(task.getTaskId())
|
||||
.setTaskIndex(task.getTaskIndex())
|
||||
.setWorkerActor(ByteString.copyFrom(serializedActorHandle));
|
||||
nodeBuilder.addExecutionTasks(taskBuilder.build());
|
||||
}
|
||||
|
||||
// build edges
|
||||
for (ExecutionEdge edge : node.getInputsEdges()) {
|
||||
nodeBuilder.addInputEdges(buildEdgePb(edge));
|
||||
}
|
||||
for (ExecutionEdge edge : node.getOutputEdges()) {
|
||||
nodeBuilder.addOutputEdges(buildEdgePb(edge));
|
||||
}
|
||||
|
||||
builder.addExecutionNodes(nodeBuilder.build());
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
private RemoteCall.ExecutionGraph.ExecutionEdge buildEdgePb(ExecutionEdge edge) {
|
||||
RemoteCall.ExecutionGraph.ExecutionEdge.Builder edgeBuilder =
|
||||
RemoteCall.ExecutionGraph.ExecutionEdge.newBuilder();
|
||||
edgeBuilder.setSrcNodeId(edge.getSrcNodeId());
|
||||
edgeBuilder.setTargetNodeId(edge.getTargetNodeId());
|
||||
edgeBuilder.setPartition(ByteString.copyFrom(serializePartition(edge.getPartition())));
|
||||
return edgeBuilder.build();
|
||||
}
|
||||
|
||||
private byte[] serializeFunction(Function function) {
|
||||
if (function instanceof PythonFunction) {
|
||||
PythonFunction pyFunc = (PythonFunction) function;
|
||||
// function_bytes, module_name, class_name, function_name, function_interface
|
||||
return serializer.serialize(Arrays.asList(
|
||||
pyFunc.getFunction(), pyFunc.getModuleName(),
|
||||
pyFunc.getClassName(), pyFunc.getFunctionName(),
|
||||
pyFunc.getFunctionInterface()
|
||||
));
|
||||
} else {
|
||||
return new byte[0];
|
||||
}
|
||||
}
|
||||
|
||||
private byte[] serializePartition(Partition partition) {
|
||||
if (partition instanceof PythonPartition) {
|
||||
PythonPartition pythonPartition = (PythonPartition) partition;
|
||||
// partition_bytes, module_name, class_name, function_name
|
||||
return serializer.serialize(Arrays.asList(
|
||||
pythonPartition.getPartition(), pythonPartition.getModuleName(),
|
||||
pythonPartition.getClassName(), pythonPartition.getFunctionName()
|
||||
));
|
||||
} else {
|
||||
return new byte[0];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
+119
@@ -0,0 +1,119 @@
|
||||
package org.ray.streaming.runtime.python;
|
||||
|
||||
import com.google.common.io.BaseEncoding;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.msgpack.core.MessageBufferPacker;
|
||||
import org.msgpack.core.MessagePack;
|
||||
import org.msgpack.core.MessageUnpacker;
|
||||
import org.msgpack.value.ArrayValue;
|
||||
import org.msgpack.value.FloatValue;
|
||||
import org.msgpack.value.IntegerValue;
|
||||
import org.msgpack.value.MapValue;
|
||||
import org.msgpack.value.Value;
|
||||
|
||||
public class MsgPackSerializer {
|
||||
|
||||
public byte[] serialize(Object obj) {
|
||||
MessageBufferPacker packer = MessagePack.newDefaultBufferPacker();
|
||||
serialize(obj, packer);
|
||||
return packer.toByteArray();
|
||||
}
|
||||
|
||||
private void serialize(Object obj, MessageBufferPacker packer) {
|
||||
try {
|
||||
if (obj == null) {
|
||||
packer.packNil();
|
||||
} else {
|
||||
Class<?> clz = obj.getClass();
|
||||
if (clz == Boolean.class) {
|
||||
packer.packBoolean((Boolean) obj);
|
||||
} else if (clz == Integer.class) {
|
||||
packer.packInt((Integer) obj);
|
||||
} else if (clz == Long.class) {
|
||||
packer.packLong((Long) obj);
|
||||
} else if (clz == Double.class) {
|
||||
packer.packDouble((Double) obj);
|
||||
} else if (clz == byte[].class) {
|
||||
byte[] bytes = (byte[]) obj;
|
||||
packer.packBinaryHeader(bytes.length);
|
||||
packer.writePayload(bytes);
|
||||
} else if (clz == String.class) {
|
||||
packer.packString((String) obj);
|
||||
} else if (obj instanceof Collection) {
|
||||
Collection collection = (Collection) (obj);
|
||||
packer.packArrayHeader(collection.size());
|
||||
for (Object o : collection) {
|
||||
serialize(o, packer);
|
||||
}
|
||||
} else if (obj instanceof Map) {
|
||||
Map map = (Map) (obj);
|
||||
packer.packMapHeader(map.size());
|
||||
for (Object o : map.entrySet()) {
|
||||
Map.Entry e = (Map.Entry) o;
|
||||
serialize(e.getKey(), packer);
|
||||
serialize(e.getValue(), packer);
|
||||
}
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Unsupported type " + clz);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Serialize error for object " + obj, e);
|
||||
}
|
||||
}
|
||||
|
||||
public Object deserialize(byte[] bytes) {
|
||||
try {
|
||||
MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(bytes);
|
||||
return deserialize(unpacker.unpackValue());
|
||||
} catch (Exception e) {
|
||||
String hex = BaseEncoding.base16().lowerCase().encode(bytes);
|
||||
throw new RuntimeException("Deserialize error: " + hex, e);
|
||||
}
|
||||
}
|
||||
|
||||
private Object deserialize(Value value) {
|
||||
switch (value.getValueType()) {
|
||||
case NIL:
|
||||
return null;
|
||||
case BOOLEAN:
|
||||
return value.asBooleanValue().getBoolean();
|
||||
case INTEGER:
|
||||
IntegerValue iv = value.asIntegerValue();
|
||||
if (iv.isInIntRange()) {
|
||||
return iv.toInt();
|
||||
} else if (iv.isInLongRange()) {
|
||||
return iv.toLong();
|
||||
} else {
|
||||
return iv.toBigInteger();
|
||||
}
|
||||
case FLOAT:
|
||||
FloatValue fv = value.asFloatValue();
|
||||
return fv.toDouble();
|
||||
case STRING:
|
||||
return value.asStringValue().asString();
|
||||
case BINARY:
|
||||
return value.asBinaryValue().asByteArray();
|
||||
case ARRAY:
|
||||
ArrayValue arrayValue = value.asArrayValue();
|
||||
List<Object> list = new ArrayList<>(arrayValue.size());
|
||||
for (Value elem : arrayValue) {
|
||||
list.add(deserialize(elem));
|
||||
}
|
||||
return list;
|
||||
case MAP:
|
||||
MapValue mapValue = value.asMapValue();
|
||||
Map<Object, Object> map = new HashMap<>();
|
||||
for (Map.Entry<Value, Value> entry : mapValue.entrySet()) {
|
||||
map.put(deserialize(entry.getKey()), deserialize(entry.getValue()));
|
||||
}
|
||||
return map;
|
||||
default:
|
||||
throw new UnsupportedOperationException("Unsupported type " + value.getValueType());
|
||||
}
|
||||
}
|
||||
}
|
||||
+152
@@ -0,0 +1,152 @@
|
||||
package org.ray.streaming.runtime.python;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import org.msgpack.core.Preconditions;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.streaming.api.context.StreamingContext;
|
||||
import org.ray.streaming.python.PythonFunction;
|
||||
import org.ray.streaming.python.PythonPartition;
|
||||
import org.ray.streaming.python.stream.PythonStreamSource;
|
||||
import org.ray.streaming.runtime.util.ReflectionUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* Gateway for streaming python api.
|
||||
* All calls on DataStream in python will be mapped to DataStream call in java by this
|
||||
* PythonGateway using ray calls.
|
||||
* <p>
|
||||
* Note: this class needs to be in sync with `GatewayClient` in
|
||||
* `streaming/python/runtime/gateway_client.py`
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
@RayRemote
|
||||
public class PythonGateway {
|
||||
private static final Logger LOG = LoggerFactory.getLogger(PythonGateway.class);
|
||||
private static final String REFERENCE_ID_PREFIX = "__gateway_reference_id__";
|
||||
|
||||
private MsgPackSerializer serializer;
|
||||
private Map<String, Object> referenceMap;
|
||||
private StreamingContext streamingContext;
|
||||
|
||||
public PythonGateway() {
|
||||
serializer = new MsgPackSerializer();
|
||||
referenceMap = new HashMap<>();
|
||||
LOG.info("PythonGateway created");
|
||||
}
|
||||
|
||||
public byte[] createStreamingContext() {
|
||||
streamingContext = StreamingContext.buildContext();
|
||||
LOG.info("StreamingContext created");
|
||||
referenceMap.put(getReferenceId(streamingContext), streamingContext);
|
||||
return serializer.serialize(getReferenceId(streamingContext));
|
||||
}
|
||||
|
||||
public StreamingContext getStreamingContext() {
|
||||
return streamingContext;
|
||||
}
|
||||
|
||||
public byte[] withConfig(byte[] confBytes) {
|
||||
Preconditions.checkNotNull(streamingContext);
|
||||
try {
|
||||
Map<String, String> config = (Map<String, String>) serializer.deserialize(confBytes);
|
||||
LOG.info("Set config {}", config);
|
||||
streamingContext.withConfig(config);
|
||||
// We can't use `return void`, that will make `ray.get()` hang forever.
|
||||
// We can't using `return new byte[0]`, that will make `ray::CoreWorker::ExecuteTask` crash.
|
||||
// So we `return new byte[1]` for method execution success.
|
||||
// Same for other methods in this class which return new byte[1].
|
||||
return new byte[1];
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public byte[] createPythonStreamSource(byte[] pySourceFunc) {
|
||||
Preconditions.checkNotNull(streamingContext);
|
||||
try {
|
||||
PythonStreamSource pythonStreamSource = PythonStreamSource.from(
|
||||
streamingContext, PythonFunction.fromFunction(pySourceFunc));
|
||||
referenceMap.put(getReferenceId(pythonStreamSource), pythonStreamSource);
|
||||
return serializer.serialize(getReferenceId(pythonStreamSource));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public byte[] execute(byte[] jobNameBytes) {
|
||||
LOG.info("Starting executing");
|
||||
streamingContext.execute((String) serializer.deserialize(jobNameBytes));
|
||||
// see `withConfig` method.
|
||||
return new byte[1];
|
||||
}
|
||||
|
||||
public byte[] createPyFunc(byte[] pyFunc) {
|
||||
PythonFunction function = PythonFunction.fromFunction(pyFunc);
|
||||
referenceMap.put(getReferenceId(function), function);
|
||||
return serializer.serialize(getReferenceId(function));
|
||||
}
|
||||
|
||||
public byte[] createPyPartition(byte[] pyPartition) {
|
||||
PythonPartition partition = new PythonPartition(pyPartition);
|
||||
referenceMap.put(getReferenceId(partition), partition);
|
||||
return serializer.serialize(getReferenceId(partition));
|
||||
}
|
||||
|
||||
public byte[] callFunction(byte[] paramsBytes) {
|
||||
try {
|
||||
List<Object> params = (List<Object>) serializer.deserialize(paramsBytes);
|
||||
params = processReferenceParameters(params);
|
||||
LOG.info("callFunction params {}", params);
|
||||
String className = (String) params.get(0);
|
||||
String funcName = (String) params.get(1);
|
||||
Class<?> clz = Class.forName(className, true, this.getClass().getClassLoader());
|
||||
Method method = ReflectionUtils.findMethod(clz, funcName);
|
||||
Object result = method.invoke(null, params.subList(2, params.size()).toArray());
|
||||
referenceMap.put(getReferenceId(result), result);
|
||||
return serializer.serialize(getReferenceId(result));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public byte[] callMethod(byte[] paramsBytes) {
|
||||
try {
|
||||
List<Object> params = (List<Object>) serializer.deserialize(paramsBytes);
|
||||
params = processReferenceParameters(params);
|
||||
LOG.info("callMethod params {}", params);
|
||||
Object obj = params.get(0);
|
||||
String methodName = (String) params.get(1);
|
||||
Method method = ReflectionUtils.findMethod(obj.getClass(), methodName);
|
||||
Object result = method.invoke(obj, params.subList(2, params.size()).toArray());
|
||||
referenceMap.put(getReferenceId(result), result);
|
||||
return serializer.serialize(getReferenceId(result));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private List<Object> processReferenceParameters(List<Object> params) {
|
||||
return params.stream().map(this::processReferenceParameter)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private Object processReferenceParameter(Object o) {
|
||||
if (o instanceof String) {
|
||||
Object value = referenceMap.get(o);
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
return o;
|
||||
}
|
||||
|
||||
private String getReferenceId(Object o) {
|
||||
return REFERENCE_ID_PREFIX + System.identityHashCode(o);
|
||||
}
|
||||
|
||||
}
|
||||
+45
-16
@@ -6,12 +6,14 @@ import java.util.Map;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.streaming.api.Language;
|
||||
import org.ray.streaming.jobgraph.JobGraph;
|
||||
import org.ray.streaming.jobgraph.JobVertex;
|
||||
import org.ray.streaming.runtime.cluster.ResourceManager;
|
||||
import org.ray.streaming.runtime.core.graph.ExecutionGraph;
|
||||
import org.ray.streaming.runtime.core.graph.ExecutionNode;
|
||||
import org.ray.streaming.runtime.core.graph.ExecutionTask;
|
||||
import org.ray.streaming.runtime.generated.RemoteCall;
|
||||
import org.ray.streaming.runtime.python.GraphPbBuilder;
|
||||
import org.ray.streaming.runtime.worker.JobWorker;
|
||||
import org.ray.streaming.runtime.worker.context.WorkerContext;
|
||||
import org.ray.streaming.schedule.JobScheduler;
|
||||
@@ -23,43 +25,70 @@ import org.ray.streaming.schedule.JobScheduler;
|
||||
public class JobSchedulerImpl implements JobScheduler {
|
||||
private JobGraph jobGraph;
|
||||
private Map<String, String> jobConfig;
|
||||
private ResourceManager resourceManager;
|
||||
private TaskAssigner taskAssigner;
|
||||
|
||||
public JobSchedulerImpl() {
|
||||
this.resourceManager = new ResourceManager();
|
||||
this.taskAssigner = new TaskAssignerImpl();
|
||||
}
|
||||
|
||||
/**
|
||||
* Schedule physical plan to execution graph, and call streaming worker to init and run.
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public void schedule(JobGraph jobGraph, Map<String, String> jobConfig) {
|
||||
this.jobConfig = jobConfig;
|
||||
this.jobGraph = jobGraph;
|
||||
System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
|
||||
Ray.init();
|
||||
|
||||
List<RayActor<JobWorker>> workers = this.resourceManager.createWorkers(getPlanWorker());
|
||||
ExecutionGraph executionGraph = this.taskAssigner.assign(this.jobGraph, workers);
|
||||
if (Ray.internal() == null) {
|
||||
System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
|
||||
Ray.init();
|
||||
}
|
||||
|
||||
ExecutionGraph executionGraph = this.taskAssigner.assign(this.jobGraph);
|
||||
List<ExecutionNode> executionNodes = executionGraph.getExecutionNodeList();
|
||||
List<RayObject<Boolean>> waits = new ArrayList<>();
|
||||
boolean hasPythonNode = executionNodes.stream()
|
||||
.allMatch(node -> node.getLanguage() == Language.PYTHON);
|
||||
RemoteCall.ExecutionGraph executionGraphPb = null;
|
||||
if (hasPythonNode) {
|
||||
executionGraphPb = new GraphPbBuilder().buildExecutionGraphPb(executionGraph);
|
||||
}
|
||||
List<RayObject<Object>> waits = new ArrayList<>();
|
||||
for (ExecutionNode executionNode : executionNodes) {
|
||||
List<ExecutionTask> executionTasks = executionNode.getExecutionTasks();
|
||||
for (ExecutionTask executionTask : executionTasks) {
|
||||
int taskId = executionTask.getTaskId();
|
||||
RayActor<JobWorker> streamWorker = executionTask.getWorker();
|
||||
waits.add(Ray.call(JobWorker::init, streamWorker,
|
||||
new WorkerContext(taskId, executionGraph, jobConfig)));
|
||||
RayActor worker = executionTask.getWorker();
|
||||
switch (executionNode.getLanguage()) {
|
||||
case JAVA:
|
||||
RayActor<JobWorker> jobWorker = (RayActor<JobWorker>) worker;
|
||||
waits.add(Ray.call(JobWorker::init, jobWorker,
|
||||
new WorkerContext(taskId, executionGraph, jobConfig)));
|
||||
break;
|
||||
case PYTHON:
|
||||
byte[] workerContextBytes = buildPythonWorkerContext(
|
||||
taskId, executionGraphPb, jobConfig);
|
||||
waits.add(Ray.callPy((RayPyActor) worker,
|
||||
"init", workerContextBytes));
|
||||
break;
|
||||
default:
|
||||
throw new UnsupportedOperationException(
|
||||
"Unsupported language " + executionNode.getLanguage());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ray.wait(waits);
|
||||
}
|
||||
|
||||
private int getPlanWorker() {
|
||||
List<JobVertex> jobVertexList = jobGraph.getJobVertexList();
|
||||
return jobVertexList.stream().map(JobVertex::getParallelism).reduce(0, Integer::sum);
|
||||
private byte[] buildPythonWorkerContext(
|
||||
int taskId,
|
||||
RemoteCall.ExecutionGraph executionGraphPb,
|
||||
Map<String, String> jobConfig) {
|
||||
return RemoteCall.WorkerContext.newBuilder()
|
||||
.setTaskId(taskId)
|
||||
.putAllConf(jobConfig)
|
||||
.setGraph(executionGraphPb)
|
||||
.build()
|
||||
.toByteArray();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
+1
-4
@@ -1,11 +1,8 @@
|
||||
package org.ray.streaming.runtime.schedule;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.streaming.jobgraph.JobGraph;
|
||||
import org.ray.streaming.runtime.core.graph.ExecutionGraph;
|
||||
import org.ray.streaming.runtime.worker.JobWorker;
|
||||
|
||||
/**
|
||||
* Interface of the task assigning strategy.
|
||||
@@ -15,6 +12,6 @@ public interface TaskAssigner extends Serializable {
|
||||
/**
|
||||
* Assign logical plan to physical execution graph.
|
||||
*/
|
||||
ExecutionGraph assign(JobGraph jobGraph, List<RayActor<JobWorker>> workers);
|
||||
ExecutionGraph assign(JobGraph jobGraph);
|
||||
|
||||
}
|
||||
|
||||
+20
-8
@@ -4,7 +4,7 @@ import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.streaming.jobgraph.JobEdge;
|
||||
import org.ray.streaming.jobgraph.JobGraph;
|
||||
@@ -20,12 +20,11 @@ public class TaskAssignerImpl implements TaskAssigner {
|
||||
/**
|
||||
* Assign an optimized logical plan to execution graph.
|
||||
*
|
||||
* @param jobGraph The logical plan.
|
||||
* @param workers The worker actors.
|
||||
* @param jobGraph The logical plan.
|
||||
* @return The physical execution graph.
|
||||
*/
|
||||
@Override
|
||||
public ExecutionGraph assign(JobGraph jobGraph, List<RayActor<JobWorker>> workers) {
|
||||
public ExecutionGraph assign(JobGraph jobGraph) {
|
||||
List<JobVertex> jobVertices = jobGraph.getJobVertexList();
|
||||
List<JobEdge> jobEdges = jobGraph.getJobEdgeList();
|
||||
|
||||
@@ -37,7 +36,7 @@ public class TaskAssignerImpl implements TaskAssigner {
|
||||
executionNode.setNodeType(jobVertex.getVertexType());
|
||||
List<ExecutionTask> vertexTasks = new ArrayList<>();
|
||||
for (int taskIndex = 0; taskIndex < jobVertex.getParallelism(); taskIndex++) {
|
||||
vertexTasks.add(new ExecutionTask(taskId, taskIndex, workers.get(taskId)));
|
||||
vertexTasks.add(new ExecutionTask(taskId, taskIndex, createWorker(jobVertex)));
|
||||
taskId++;
|
||||
}
|
||||
executionNode.setExecutionTasks(vertexTasks);
|
||||
@@ -51,12 +50,25 @@ public class TaskAssignerImpl implements TaskAssigner {
|
||||
|
||||
ExecutionEdge executionEdge = new ExecutionEdge(srcNodeId, targetNodeId,
|
||||
jobEdge.getPartition());
|
||||
idToExecutionNode.get(srcNodeId).addExecutionEdge(executionEdge);
|
||||
idToExecutionNode.get(srcNodeId).addOutputEdge(executionEdge);
|
||||
idToExecutionNode.get(targetNodeId).addInputEdge(executionEdge);
|
||||
}
|
||||
|
||||
List<ExecutionNode> executionNodes = idToExecutionNode.values().stream()
|
||||
.collect(Collectors.toList());
|
||||
List<ExecutionNode> executionNodes = new ArrayList<>(idToExecutionNode.values());
|
||||
return new ExecutionGraph(executionNodes);
|
||||
}
|
||||
|
||||
private RayActor createWorker(JobVertex jobVertex) {
|
||||
switch (jobVertex.getLanguage()) {
|
||||
case PYTHON:
|
||||
return Ray.createPyActor(
|
||||
"ray.streaming.runtime.worker", "JobWorker");
|
||||
case JAVA:
|
||||
return Ray.createActor(JobWorker::new);
|
||||
default:
|
||||
throw new UnsupportedOperationException(
|
||||
"Unsupported language " + jobVertex.getLanguage());
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+86
@@ -0,0 +1,86 @@
|
||||
package org.ray.streaming.runtime.util;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
|
||||
@SuppressWarnings("UnstableApiUsage")
|
||||
public class ReflectionUtils {
|
||||
|
||||
public static Method findMethod(Class<?> cls, String methodName) {
|
||||
List<Method> methods = findMethods(cls, methodName);
|
||||
Preconditions.checkArgument(methods.size() == 1);
|
||||
return methods.get(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* For covariant return type, return the most specific method.
|
||||
* @return all methods named by {@code methodName},
|
||||
*/
|
||||
public static List<Method> findMethods(Class<?> cls, String methodName) {
|
||||
List<Class<?>> classes = new ArrayList<>();
|
||||
Class<?> clazz = cls;
|
||||
while (clazz != null) {
|
||||
classes.add(clazz);
|
||||
clazz = clazz.getSuperclass();
|
||||
}
|
||||
classes.addAll(getAllInterfaces(cls));
|
||||
if (classes.indexOf(Object.class) == -1) {
|
||||
classes.add(Object.class);
|
||||
}
|
||||
|
||||
LinkedHashMap<List<Class<?>>, Method> methods = new LinkedHashMap<>();
|
||||
for (Class<?> superClass : classes) {
|
||||
for (Method m : superClass.getDeclaredMethods()) {
|
||||
if (m.getName().equals(methodName)) {
|
||||
List<Class<?>> params = Arrays.asList(m.getParameterTypes());
|
||||
Method method = methods.get(params);
|
||||
if (method == null) {
|
||||
methods.put(params, m);
|
||||
} else {
|
||||
// for covariant return type, use the most specific method
|
||||
if (method.getReturnType().isAssignableFrom(m.getReturnType())) {
|
||||
methods.put(params, m);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return new ArrayList<>(methods.values());
|
||||
}
|
||||
|
||||
/**
|
||||
* <p>Gets a <code>List</code> of all interfaces implemented by the given
|
||||
* class and its superclasses.</p>
|
||||
* <p>The order is determined by looking through each interface in turn as
|
||||
* declared in the source file and following its hierarchy up.</p>
|
||||
*/
|
||||
public static List<Class<?>> getAllInterfaces(Class<?> cls) {
|
||||
if (cls == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
LinkedHashSet<Class<?>> interfacesFound = new LinkedHashSet<>();
|
||||
getAllInterfaces(cls, interfacesFound);
|
||||
return new ArrayList<>(interfacesFound);
|
||||
}
|
||||
|
||||
private static void getAllInterfaces(Class<?> cls, LinkedHashSet<Class<?>> interfacesFound) {
|
||||
while (cls != null) {
|
||||
Class[] interfaces = cls.getInterfaces();
|
||||
for (Class anInterface : interfaces) {
|
||||
if (!interfacesFound.contains(anInterface)) {
|
||||
interfacesFound.add(anInterface);
|
||||
getAllInterfaces(anInterface, interfacesFound);
|
||||
}
|
||||
}
|
||||
|
||||
cls = cls.getSuperclass();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
+1
@@ -2,6 +2,7 @@ package org.ray.streaming.runtime.worker;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Map;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.runtime.RayMultiWorkerNativeRuntime;
|
||||
|
||||
+2
-2
@@ -65,7 +65,7 @@ public abstract class StreamTask implements Runnable {
|
||||
List<Collector> collectors = new ArrayList<>();
|
||||
for (ExecutionEdge edge : outputEdges) {
|
||||
Map<String, ActorId> outputActorIds = new HashMap<>();
|
||||
Map<Integer, RayActor<JobWorker>> taskId2Worker = executionGraph
|
||||
Map<Integer, RayActor> taskId2Worker = executionGraph
|
||||
.getTaskId2WorkerByNodeId(edge.getTargetNodeId());
|
||||
taskId2Worker.forEach((targetTaskId, targetActor) -> {
|
||||
String queueName = ChannelID.genIdStr(taskId, targetTaskId, executionGraph.getBuildTime());
|
||||
@@ -91,7 +91,7 @@ public abstract class StreamTask implements Runnable {
|
||||
List<ExecutionEdge> inputEdges = executionNode.getInputsEdges();
|
||||
Map<String, ActorId> inputActorIds = new HashMap<>();
|
||||
for (ExecutionEdge edge : inputEdges) {
|
||||
Map<Integer, RayActor<JobWorker>> taskId2Worker = executionGraph
|
||||
Map<Integer, RayActor> taskId2Worker = executionGraph
|
||||
.getTaskId2WorkerByNodeId(edge.getSrcNodeId());
|
||||
taskId2Worker.forEach((srcTaskId, srcActor) -> {
|
||||
String queueName = ChannelID.genIdStr(srcTaskId, taskId, executionGraph.getBuildTime());
|
||||
|
||||
+39
@@ -0,0 +1,39 @@
|
||||
package org.ray.streaming.runtime.python;
|
||||
|
||||
import static org.testng.Assert.assertEquals;
|
||||
import static org.testng.Assert.assertTrue;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public class MsgPackSerializerTest {
|
||||
|
||||
@Test
|
||||
public void testSerialize() {
|
||||
MsgPackSerializer serializer = new MsgPackSerializer();
|
||||
|
||||
Map map = new HashMap();
|
||||
List list = new ArrayList<>();
|
||||
list.add(null);
|
||||
list.add(true);
|
||||
list.add(1);
|
||||
list.add(1.0d);
|
||||
list.add("str");
|
||||
map.put("k1", "value1");
|
||||
map.put("k2", 2);
|
||||
map.put("k3", list);
|
||||
byte[] bytes = serializer.serialize(map);
|
||||
Object o = serializer.deserialize(bytes);
|
||||
assertEquals(o, map);
|
||||
|
||||
byte[] binary = {1, 2, 3, 4};
|
||||
assertTrue(Arrays.equals(
|
||||
binary, (byte[]) (serializer.deserialize(serializer.serialize(binary)))));
|
||||
}
|
||||
|
||||
}
|
||||
+48
@@ -0,0 +1,48 @@
|
||||
package org.ray.streaming.runtime.python;
|
||||
|
||||
import static org.testng.Assert.assertEquals;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.ray.streaming.api.stream.StreamSink;
|
||||
import org.ray.streaming.jobgraph.JobGraph;
|
||||
import org.ray.streaming.jobgraph.JobGraphBuilder;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class PythonGatewayTest {
|
||||
|
||||
@Test
|
||||
public void testPythonGateway() {
|
||||
MsgPackSerializer serializer = new MsgPackSerializer();
|
||||
PythonGateway gateway = new PythonGateway();
|
||||
gateway.createStreamingContext();
|
||||
Map<String, String> config = new HashMap<>();
|
||||
config.put("k1", "v1");
|
||||
gateway.withConfig(serializer.serialize(config));
|
||||
byte[] mockPySource = new byte[0];
|
||||
Object source = serializer.deserialize(
|
||||
gateway.createPythonStreamSource(mockPySource));
|
||||
byte[] mockPyFunc = new byte[0];
|
||||
Object mapPyFunc = serializer.deserialize(gateway.createPyFunc(mockPyFunc));
|
||||
Object mapStream = serializer.deserialize(
|
||||
gateway.callMethod(
|
||||
serializer.serialize(Arrays.asList(source, "map", mapPyFunc))));
|
||||
byte[] mockPyPartition = new byte[0];
|
||||
Object partition = serializer.deserialize(
|
||||
gateway.createPyPartition(mockPyPartition));
|
||||
Object partitionedStream = serializer.deserialize(
|
||||
gateway.callMethod(
|
||||
serializer.serialize(Arrays.asList(mapStream, "partitionBy", partition))));
|
||||
byte[] mockSinkFunc = new byte[0];
|
||||
Object sinkPyFunc = serializer.deserialize(gateway.createPyFunc(mockSinkFunc));
|
||||
gateway.callMethod(
|
||||
serializer.serialize(Arrays.asList(partitionedStream, "sink", sinkPyFunc)));
|
||||
List<StreamSink> streamSinks = gateway.getStreamingContext().getStreamSinks();
|
||||
assertEquals(streamSinks.size(), 1);
|
||||
JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(streamSinks, "py_job");
|
||||
JobGraph jobGraph = jobGraphBuilder.build();
|
||||
jobGraph.printJobGraph();
|
||||
}
|
||||
}
|
||||
+10
-19
@@ -1,26 +1,20 @@
|
||||
package org.ray.streaming.runtime.schedule;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.actor.LocalModeRayActor;
|
||||
import java.util.List;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.streaming.api.context.StreamingContext;
|
||||
import org.ray.streaming.api.partition.impl.RoundRobinPartition;
|
||||
import org.ray.streaming.api.stream.DataStream;
|
||||
import org.ray.streaming.api.stream.DataStreamSink;
|
||||
import org.ray.streaming.api.stream.DataStreamSource;
|
||||
import org.ray.streaming.jobgraph.JobGraph;
|
||||
import org.ray.streaming.jobgraph.JobGraphBuilder;
|
||||
import org.ray.streaming.runtime.BaseUnitTest;
|
||||
import org.ray.streaming.runtime.core.graph.ExecutionEdge;
|
||||
import org.ray.streaming.runtime.core.graph.ExecutionGraph;
|
||||
import org.ray.streaming.runtime.core.graph.ExecutionNode;
|
||||
import org.ray.streaming.runtime.core.graph.ExecutionNode.NodeType;
|
||||
import org.ray.streaming.runtime.worker.JobWorker;
|
||||
import org.ray.streaming.jobgraph.JobGraph;
|
||||
import org.ray.streaming.jobgraph.JobGraphBuilder;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.testng.Assert;
|
||||
@@ -32,15 +26,11 @@ public class TaskAssignerImplTest extends BaseUnitTest {
|
||||
|
||||
@Test
|
||||
public void testTaskAssignImpl() {
|
||||
Ray.init();
|
||||
JobGraph jobGraph = buildDataSyncPlan();
|
||||
|
||||
List<RayActor<JobWorker>> workers = new ArrayList<>();
|
||||
for(int i = 0; i < jobGraph.getJobVertexList().size(); i++) {
|
||||
workers.add(new LocalModeRayActor(ActorId.fromRandom(), ObjectId.fromRandom()));
|
||||
}
|
||||
|
||||
TaskAssigner taskAssigner = new TaskAssignerImpl();
|
||||
ExecutionGraph executionGraph = taskAssigner.assign(jobGraph, workers);
|
||||
ExecutionGraph executionGraph = taskAssigner.assign(jobGraph);
|
||||
|
||||
List<ExecutionNode> executionNodeList = executionGraph.getExecutionNodeList();
|
||||
|
||||
@@ -61,16 +51,17 @@ public class TaskAssignerImplTest extends BaseUnitTest {
|
||||
Assert.assertEquals(sinkNode.getNodeType(), NodeType.SINK);
|
||||
Assert.assertEquals(sinkNode.getExecutionTasks().size(), 1);
|
||||
Assert.assertEquals(sinkNode.getOutputEdges().size(), 0);
|
||||
|
||||
Ray.shutdown();
|
||||
}
|
||||
|
||||
public JobGraph buildDataSyncPlan() {
|
||||
StreamingContext streamingContext = StreamingContext.buildContext();
|
||||
DataStream<String> dataStream = DataStreamSource.buildSource(streamingContext,
|
||||
Lists.newArrayList("a", "b", "c"));
|
||||
DataStreamSink streamSink = dataStream.sink(x -> LOGGER.info(x));
|
||||
DataStreamSink streamSink = dataStream.sink(LOGGER::info);
|
||||
JobGraphBuilder jobGraphBuilder = new JobGraphBuilder(Lists.newArrayList(streamSink));
|
||||
|
||||
JobGraph jobGraph = jobGraphBuilder.build();
|
||||
return jobGraph;
|
||||
return jobGraphBuilder.build();
|
||||
}
|
||||
}
|
||||
|
||||
+38
@@ -0,0 +1,38 @@
|
||||
package org.ray.streaming.runtime.util;
|
||||
|
||||
import static org.testng.Assert.assertEquals;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Collections;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class ReflectionUtilsTest {
|
||||
|
||||
static class Foo implements Serializable {
|
||||
public void f1() {
|
||||
}
|
||||
|
||||
public void f2() {
|
||||
}
|
||||
|
||||
public void f2(boolean a1) {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFindMethod() throws NoSuchMethodException {
|
||||
assertEquals(Foo.class.getDeclaredMethod("f1"),
|
||||
ReflectionUtils.findMethod(Foo.class, "f1"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFindMethods() {
|
||||
assertEquals(ReflectionUtils.findMethods(Foo.class, "f2").size(), 2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetAllInterfaces() {
|
||||
assertEquals(ReflectionUtils.getAllInterfaces(Foo.class),
|
||||
Collections.singletonList(Serializable.class));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user