Revert "[Streaming] Fault Tolerance Implementation (#10008)" (#10582)

This reverts commit 1b1466748f.
This commit is contained in:
SangBin Cho
2020-09-04 13:21:18 -07:00
committed by GitHub
parent da83bbd764
commit cb919c5e5c
158 changed files with 1227 additions and 7040 deletions
@@ -1,9 +1,7 @@
package io.ray.streaming.runtime.config;
import com.google.common.base.Preconditions;
import io.ray.streaming.runtime.config.global.CheckpointConfig;
import io.ray.streaming.runtime.config.global.CommonConfig;
import io.ray.streaming.runtime.config.global.ContextBackendConfig;
import io.ray.streaming.runtime.config.global.TransferConfig;
import java.io.Serializable;
import java.lang.reflect.Method;
@@ -21,19 +19,17 @@ import org.slf4j.LoggerFactory;
public class StreamingGlobalConfig implements Serializable {
private static final Logger LOG = LoggerFactory.getLogger(StreamingGlobalConfig.class);
public final CommonConfig commonConfig;
public final TransferConfig transferConfig;
public final Map<String, String> configMap;
public CheckpointConfig checkpointConfig;
public ContextBackendConfig contextBackendConfig;
/**
 * Builds the global streaming configuration from a raw key/value map.
 *
 * @param conf raw configuration entries; a copy of them is kept in {@code configMap}
 */
public StreamingGlobalConfig(final Map<String, String> conf) {
// Keep our own copy of the raw entries instead of holding on to the caller's map.
configMap = new HashMap<>(conf);
// Every typed config view below is created by ConfigFactory from the same raw map.
commonConfig = ConfigFactory.create(CommonConfig.class, conf);
transferConfig = ConfigFactory.create(TransferConfig.class, conf);
checkpointConfig = ConfigFactory.create(CheckpointConfig.class, conf);
contextBackendConfig = ConfigFactory.create(ContextBackendConfig.class, conf);
// NOTE(review): presumably writes the resolved values of the typed configs back into
// configMap - confirm against globalConfig2Map(), which is not visible here.
globalConfig2Map();
}
@@ -1,55 +0,0 @@
package io.ray.streaming.runtime.config.global;
import io.ray.streaming.runtime.config.Config;
import org.aeonbits.owner.Mutable;
/**
 * Checkpoint-related settings of a streaming job.
 */
public interface CheckpointConfig extends Config, Mutable {

  String CP_INTERVAL_SECS = "streaming.checkpoint.interval.secs";
  String CP_TIMEOUT_SECS = "streaming.checkpoint.timeout.secs";
  String CP_PREFIX_KEY_MASTER = "streaming.checkpoint.prefix-key.job-master.context";
  String CP_PREFIX_KEY_WORKER = "streaming.checkpoint.prefix-key.job-worker.context";
  String CP_PREFIX_KEY_OPERATOR = "streaming.checkpoint.prefix-key.job-worker.operator";

  /**
   * Minimum time, in seconds, between two checkpoints: JobMaster will not trigger a second
   * checkpoint before this interval has elapsed.
   */
  @Key(CP_INTERVAL_SECS)
  @DefaultValue("5")
  int cpIntervalSecs();

  /**
   * How long, in seconds, JobMaster waits for a checkpoint to finish. If the timeout is reached
   * before all workers have committed, JobMaster treats the checkpoint as failed and triggers
   * another one.
   */
  @Key(CP_TIMEOUT_SECS)
  @DefaultValue("120")
  int cpTimeoutSecs();

  /**
   * Key prefix under which JobMaster's context is saved to storage. Users usually do not need
   * to change this.
   */
  @Key(CP_PREFIX_KEY_MASTER)
  @DefaultValue("job_master_runtime_context_")
  String jobMasterContextCpPrefixKey();

  /**
   * Key prefix under which JobWorker's context is saved to storage. Users usually do not need
   * to change this.
   */
  @Key(CP_PREFIX_KEY_WORKER)
  @DefaultValue("job_worker_context_")
  String jobWorkerContextCpPrefixKey();

  /**
   * Key prefix under which the context of a user operator (in StreamTask) is saved to storage.
   * Users usually do not need to change this.
   */
  @Key(CP_PREFIX_KEY_OPERATOR)
  @DefaultValue("job_worker_op_")
  String jobWorkerOpCpPrefixKey();
}
@@ -1,17 +0,0 @@
package io.ray.streaming.runtime.config.global;
import org.aeonbits.owner.Config;
/**
 * Settings that select and parameterize the context backend.
 */
public interface ContextBackendConfig extends Config {

  String STATE_BACKEND_TYPE = "streaming.context-backend.type";
  String FILE_STATE_ROOT_PATH = "streaming.context-backend.file-state.root";

  /**
   * Name of the backend implementation to use; defaults to {@code memory}.
   */
  @Key(STATE_BACKEND_TYPE)
  @DefaultValue("memory")
  String stateBackendType();

  /**
   * Root directory for file-based state; defaults to {@code /tmp/ray_streaming_state}.
   */
  @Key(FILE_STATE_ROOT_PATH)
  @DefaultValue("/tmp/ray_streaming_state")
  String fileStateRootPath();
}
@@ -22,6 +22,13 @@ public interface TransferConfig extends Config {
@Key(value = io.ray.streaming.util.Config.CHANNEL_SIZE)
long channelSize();
/**
* Whether the DataReader is being recreated (presumably after a restart; confirm with callers). Defaults to false.
*/
@DefaultValue(value = "false")
@Key(value = io.ray.streaming.util.Config.IS_RECREATE)
boolean readerIsRecreate();
/**
* Return from DataReader.getBundle if only empty message read in this interval.
*/
@@ -1,22 +0,0 @@
package io.ray.streaming.runtime.config.types;
/**
 * Kinds of context backend that can be selected through configuration.
 *
 * <p>Each constant carries a lower-case name and a numeric index. Both are fixed at construction
 * time, so the fields are {@code final}: enum constants are shared singletons and must not hold
 * mutable state.
 */
public enum ContextBackendType {

  /**
   * Memory type.
   */
  MEMORY("memory", 0),

  /**
   * Local file.
   */
  LOCAL_FILE("local_file", 1);

  private final String name;
  private final int index;

  /**
   * @param name lower-case name of the backend type
   * @param index numeric index of the backend type
   */
  ContextBackendType(String name, int index) {
    this.name = name;
    this.index = index;
  }
}
@@ -1,42 +0,0 @@
package io.ray.streaming.runtime.context;
import io.ray.streaming.runtime.master.JobMaster;
import io.ray.streaming.runtime.worker.JobWorker;
/**
 * Key/value storage abstraction used to persist the context of {@link JobWorker} and
 * {@link JobMaster}. Checkpoints returned by user functions are saved through it as well.
 */
public interface ContextBackend {

  /**
   * Tells whether a value is stored under the given key.
   *
   * @param key key to look up
   * @return {@code true} if the key exists
   * @throws Exception if the underlying storage fails
   */
  boolean exists(String key) throws Exception;

  /**
   * Reads the content stored under the given key.
   *
   * @param key key to read
   * @return the stored bytes
   * @throws Exception if the underlying storage fails
   */
  byte[] get(String key) throws Exception;

  /**
   * Stores content under the given key.
   *
   * @param key key to write
   * @param value content to store
   * @throws Exception if the underlying storage fails
   */
  void put(String key, byte[] value) throws Exception;

  /**
   * Removes the content stored under the given key.
   *
   * @param key key to remove
   * @throws Exception if the underlying storage fails
   */
  void remove(String key) throws Exception;
}
@@ -1,27 +0,0 @@
package io.ray.streaming.runtime.context;
import io.ray.streaming.runtime.config.StreamingGlobalConfig;
import io.ray.streaming.runtime.config.types.ContextBackendType;
import io.ray.streaming.runtime.context.impl.AtomicFsBackend;
import io.ray.streaming.runtime.context.impl.MemoryContextBackend;
/**
 * Factory that creates the {@link ContextBackend} implementation named in the configuration.
 */
public class ContextBackendFactory {

  /**
   * Creates the context backend selected by {@code config.contextBackendConfig}.
   *
   * <p>The configured type name is matched case-insensitively against the constants of
   * {@link ContextBackendType}.
   *
   * @param config global config carrying the context backend settings
   * @return a new backend instance of the configured type
   * @throws IllegalArgumentException if the configured type name matches no
   *     {@link ContextBackendType} constant
   */
  public static ContextBackend getContextBackend(final StreamingGlobalConfig config) {
    final String typeName = config.contextBackendConfig.stateBackendType();
    final ContextBackendType type;
    try {
      // Locale.ROOT keeps the case conversion independent of the JVM's default locale; with the
      // default-locale overload "local_file" would not upper-case to "LOCAL_FILE" under e.g. a
      // Turkish locale and the lookup would fail.
      type = ContextBackendType.valueOf(typeName.toUpperCase(java.util.Locale.ROOT));
    } catch (IllegalArgumentException e) {
      // Same exception type valueOf throws, but naming the offending configuration value.
      throw new IllegalArgumentException("Unsupported context backend type: " + typeName, e);
    }

    switch (type) {
      case MEMORY:
        return new MemoryContextBackend(config.contextBackendConfig);
      case LOCAL_FILE:
        return new AtomicFsBackend(config.contextBackendConfig);
      default:
        // Only reachable if a new enum constant is added without a matching case here.
        throw new RuntimeException("Unsupported context backend type.");
    }
  }
}
@@ -1,52 +0,0 @@
package io.ray.streaming.runtime.context;
import com.google.common.base.MoreObjects;
import io.ray.streaming.runtime.transfer.channel.OffsetInfo;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
/**
* This data structure contains state information of a task.
*/
public class OperatorCheckpointInfo implements Serializable {
/**
* key: channel ID, value: offset
*/
public Map<String, OffsetInfo> inputPoints;
public Map<String, OffsetInfo> outputPoints;
/**
* a serializable checkpoint returned by processor
*/
public Serializable processorCheckpoint;
public long checkpointId;
public OperatorCheckpointInfo() {
inputPoints = new HashMap<>();
outputPoints = new HashMap<>();
checkpointId = -1;
}
public OperatorCheckpointInfo(
Map<String, OffsetInfo> inputPoints,
Map<String, OffsetInfo> outputPoints,
Serializable processorCheckpoint,
long checkpointId) {
this.inputPoints = inputPoints;
this.outputPoints = outputPoints;
this.checkpointId = checkpointId;
this.processorCheckpoint = processorCheckpoint;
}
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("inputPoints", inputPoints)
.add("outputPoints", outputPoints)
.add("processorCheckpoint", processorCheckpoint)
.add("checkpointId", checkpointId)
.toString();
}
}
@@ -1,48 +0,0 @@
package io.ray.streaming.runtime.context.impl;
import io.ray.streaming.runtime.config.global.ContextBackendConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * File backend whose {@code put} replaces a value by writing a temporary file first and
 * renaming it over the real key afterwards.
 * Known issue: if the process crashes while a key is being written for the first time, this
 * scheme does not work.
 */
public class AtomicFsBackend extends LocalFileContextBackend {

  private static final Logger LOG = LoggerFactory.getLogger(AtomicFsBackend.class);
  private static final String TMP_FLAG = "_tmp";

  public AtomicFsBackend(final ContextBackendConfig config) {
    super(config);
  }

  /**
   * Reads the value of {@code key}; falls back to the temporary file when only that one exists.
   */
  @Override
  public byte[] get(String key) throws Exception {
    return super.get(onlyTmpExists(key) ? tmpKeyOf(key) : key);
  }

  /**
   * Writes {@code value} under {@code key} in four steps: promote a leftover temporary file if
   * it is the only copy, write the new value to the temporary file, delete the old value, then
   * rename the temporary file to the real key.
   */
  @Override
  public void put(String key, byte[] value) throws Exception {
    final String stagingKey = tmpKeyOf(key);
    if (onlyTmpExists(key)) {
      super.rename(stagingKey, key);
    }
    super.put(stagingKey, value);
    super.remove(key);
    super.rename(stagingKey, key);
  }

  /**
   * Removes both the temporary file (if present) and the real file of {@code key}.
   */
  @Override
  public void remove(String key) {
    final String stagingKey = tmpKeyOf(key);
    if (super.exists(stagingKey)) {
      super.remove(stagingKey);
    }
    super.remove(key);
  }

  /** Name of the temporary file that belongs to {@code key}. */
  private static String tmpKeyOf(String key) {
    return key + TMP_FLAG;
  }

  /** True when the temporary file of {@code key} exists but the real file does not. */
  private boolean onlyTmpExists(String key) {
    return super.exists(tmpKeyOf(key)) && !super.exists(key);
  }
}
@@ -1,55 +0,0 @@
package io.ray.streaming.runtime.context.impl;
import io.ray.streaming.runtime.config.global.ContextBackendConfig;
import io.ray.streaming.runtime.context.ContextBackend;
import java.io.File;
import org.apache.commons.io.FileUtils;
/**
 * Context backend that stores each key as a file under a root directory of the local file
 * system. It does not support failover across a cluster, only on a single node.
 * Writes are not atomic, so do not use this class directly; use {@link AtomicFsBackend},
 * which extends it.
 */
public class LocalFileContextBackend implements ContextBackend {

  private final String rootPath;

  public LocalFileContextBackend(ContextBackendConfig config) {
    rootPath = config.fileStateRootPath();
  }

  @Override
  public boolean exists(String key) {
    return fileFor(key).exists();
  }

  /**
   * Reads the file of {@code key}; returns {@code null} when the file does not exist.
   */
  @Override
  public byte[] get(String key) throws Exception {
    final File target = fileFor(key);
    if (!target.exists()) {
      return null;
    }
    return FileUtils.readFileToByteArray(target);
  }

  @Override
  public void put(String key, byte[] value) throws Exception {
    FileUtils.writeByteArrayToFile(fileFor(key), value);
  }

  /**
   * Deletes the file of {@code key} without reporting failures.
   */
  @Override
  public void remove(String key) {
    FileUtils.deleteQuietly(fileFor(key));
  }

  /**
   * Moves the file of {@code fromKey} to the location of {@code toKey}.
   */
  protected void rename(String fromKey, String toKey) throws Exception {
    FileUtils.moveFile(fileFor(fromKey), fileFor(toKey));
  }

  /** Resolves {@code key} to its file under the root path. */
  private File fileFor(String key) {
    return new File(rootPath, key);
  }
}
@@ -1,72 +0,0 @@
package io.ray.streaming.runtime.context.impl;
import io.ray.streaming.runtime.config.global.ContextBackendConfig;
import io.ray.streaming.runtime.context.ContextBackend;
import java.util.HashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Context backend that keeps everything in an in-process map. It does not support failover:
 * the data is lost once the worker dies.
 */
public class MemoryContextBackend implements ContextBackend {

  private static final Logger LOG = LoggerFactory.getLogger(MemoryContextBackend.class);

  private final Map<String, byte[]> kvStore = new HashMap<>();

  // Parameterized logging already skips formatting when INFO is disabled, so the log calls
  // below are issued without an explicit level check.

  public MemoryContextBackend(ContextBackendConfig config) {
    LOG.info("Start init memory state backend, config is {}.", config);
    LOG.info("Finish init memory state backend.");
  }

  @Override
  public boolean exists(String key) {
    return kvStore.containsKey(key);
  }

  @Override
  public byte[] get(final String key) {
    LOG.info("Get value of key {} start.", key);
    final byte[] stored = kvStore.get(key);
    LOG.info("Get value of key {} success.", key);
    return stored;
  }

  @Override
  public void put(final String key, final byte[] value) {
    LOG.info("Put value of key {} start.", key);
    kvStore.put(key, value);
    LOG.info("Put value of key {} success.", key);
  }

  @Override
  public void remove(final String key) {
    LOG.info("Remove value of key {} start.", key);
    kvStore.remove(key);
    LOG.info("Remove value of key {} success.", key);
  }
}
@@ -9,8 +9,8 @@ import io.ray.streaming.message.Record;
import io.ray.streaming.runtime.serialization.CrossLangSerializer;
import io.ray.streaming.runtime.serialization.JavaSerializer;
import io.ray.streaming.runtime.serialization.Serializer;
import io.ray.streaming.runtime.transfer.ChannelId;
import io.ray.streaming.runtime.transfer.DataWriter;
import io.ray.streaming.runtime.transfer.channel.ChannelId;
import java.nio.ByteBuffer;
import java.util.Collection;
import org.slf4j.Logger;
@@ -1,29 +1,19 @@
package io.ray.streaming.runtime.core.graph.executiongraph;
import com.google.common.collect.Sets;
import io.ray.api.BaseActorHandle;
import io.ray.api.id.ActorId;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Physical plan.
*/
public class ExecutionGraph implements Serializable {
private static final Logger LOG = LoggerFactory.getLogger(ExecutionGraph.class);
/**
* Name of the job.
*/
@@ -39,27 +29,6 @@ public class ExecutionGraph implements Serializable {
*/
private Map<Integer, ExecutionJobVertex> executionJobVertexMap;
/**
* Data map for execution vertex.
* key: execution vertex id.
* value: execution vertex.
*/
private Map<Integer, ExecutionVertex> executionVertexMap;
/**
* Data map for execution vertex.
* key: actor id.
* value: execution vertex.
*/
private Map<ActorId, ExecutionVertex> actorIdExecutionVertexMap;
/**
* key: channel ID
* value: actors in both sides of this channel
*/
private Map<String, Set<BaseActorHandle>> channelGroupedActors;
/**
* The max parallelism of the whole graph.
*/
@@ -85,7 +54,7 @@ public class ExecutionGraph implements Serializable {
}
public List<ExecutionJobVertex> getExecutionJobVertexList() {
return new ArrayList<>(executionJobVertexMap.values());
return new ArrayList<ExecutionJobVertex>(executionJobVertexMap.values());
}
public Map<Integer, ExecutionJobVertex> getExecutionJobVertexMap() {
@@ -96,58 +65,6 @@ public class ExecutionGraph implements Serializable {
this.executionJobVertexMap = executionJobVertexMap;
}
/**
 * Rebuilds the lookup tables that relate actors, execution vertices and channels.
 * Must be called only after the worker actors have been set on the vertices.
 */
public void generateActorMappings() {
  LOG.info("Setup queue actors relation.");

  channelGroupedActors = new HashMap<>();
  actorIdExecutionVertexMap = new HashMap<>();

  for (ExecutionVertex vertex : getAllExecutionVertices()) {
    // The vertex itself, indexed by its actor id.
    actorIdExecutionVertexMap.put(vertex.getActorId(), vertex);

    // Upstream peers: register each one under the channel it shares with this vertex.
    for (ExecutionEdge incoming : vertex.getInputEdges()) {
      final ExecutionVertex upstream = incoming.getSourceExecutionVertex();
      addActorToChannelGroupedActors(
          channelGroupedActors,
          vertex.getChannelIdByPeerVertex(upstream),
          upstream.getWorkerActor());
    }

    // Downstream peers, handled the same way.
    for (ExecutionEdge outgoing : vertex.getOutputEdges()) {
      final ExecutionVertex downstream = outgoing.getTargetExecutionVertex();
      addActorToChannelGroupedActors(
          channelGroupedActors,
          vertex.getChannelIdByPeerVertex(downstream),
          downstream.getWorkerActor());
    }
  }

  LOG.debug("Channel grouped actors is: {}.", channelGroupedActors);
}
/**
 * Adds {@code actor} to the set registered for {@code queueName}, creating the set on first use.
 */
private void addActorToChannelGroupedActors(
    Map<String, Set<BaseActorHandle>> channelGroupedActors,
    String queueName,
    BaseActorHandle actor) {
  channelGroupedActors
      .computeIfAbsent(queueName, unused -> new HashSet<>())
      .add(actor);
}
/**
 * Replaces the execution-vertex lookup map (key: execution vertex id, value: execution vertex).
 * The given map is stored as-is, not copied.
 *
 * @param executionVertexMap the new lookup map
 */
public void setExecutionVertexMap(Map<Integer, ExecutionVertex> executionVertexMap) {
this.executionVertexMap = executionVertexMap;
}
/**
 * Returns the job configuration map held by this graph (the internal map, not a copy).
 *
 * @return the job configuration
 */
public Map<String, String> getJobConfig() {
return jobConfig;
}
@@ -197,73 +114,25 @@ public class ExecutionGraph implements Serializable {
return executionJobVertexMap.values().stream()
.map(ExecutionJobVertex::getExecutionVertices)
.flatMap(Collection::stream)
.filter(ExecutionVertex::is2Add)
.filter(vertex -> vertex.is2Add())
.collect(Collectors.toList());
}
/**
* Get specified execution vertex from current execution graph by execution vertex id.
*
* @param executionVertexId execution vertex id.
* @param vertexId execution vertex id.
* @return the specified execution vertex.
*/
public ExecutionVertex getExecutionVertexByExecutionVertexId(int executionVertexId) {
  // Unknown ids are reported as an error instead of yielding null.
  if (!executionVertexMap.containsKey(executionVertexId)) {
    throw new RuntimeException("Vertex " + executionVertexId + " does not exist!");
  }
  return executionVertexMap.get(executionVertexId);
}
/**
 * Looks up an execution vertex of the current execution graph by the id of its actor.
 *
 * <p>The lookup table is built by {@code generateActorMappings()}.
 * NOTE(review): the field has no initializer at its declaration, so calling this before
 * {@code generateActorMappings()} presumably fails with a NullPointerException - confirm.
 *
 * @param actorId the actor id of the execution vertex
 * @return the matching execution vertex, or {@code null} if no vertex is mapped to the id
 */
public ExecutionVertex getExecutionVertexByActorId(ActorId actorId) {
return actorIdExecutionVertexMap.get(actorId);
}
/**
 * Finds the actor handle with the given actor id.
 *
 * @param actorId the actor id of an execution vertex
 * @return the first actor whose id equals {@code actorId}, or an empty Optional if none does
 */
public Optional<BaseActorHandle> getActorById(ActorId actorId) {
  for (BaseActorHandle candidate : getAllActors()) {
    if (candidate.getId().equals(actorId)) {
      return Optional.of(candidate);
    }
  }
  return Optional.empty();
}
/**
* Get the peer actor in the other side of channelName of a given actor
*
* @param actor actor in this side
* @param channelName the channel name
* @return the peer actor in the other side
*/
public BaseActorHandle getPeerActor(BaseActorHandle actor, String channelName) {
Set<BaseActorHandle> set = getActorsByChannelId(channelName);
final BaseActorHandle[] res = new BaseActorHandle[1];
set.forEach(anActor -> {
if (!anActor.equals(actor)) {
res[0] = anActor;
public ExecutionVertex getExecutionJobVertexByJobVertexId(int vertexId) {
for (ExecutionJobVertex executionJobVertex : executionJobVertexMap.values()) {
for (ExecutionVertex executionVertex : executionJobVertex.getExecutionVertices()) {
if (executionVertex.getExecutionVertexId() == vertexId) {
return executionVertex;
}
}
});
return res[0];
}
/**
 * Returns the actors on both sides of a channel.
 *
 * <p>For a known channel the internally stored set is returned (not a copy); for an unknown
 * channel a fresh empty set is returned.
 *
 * @param channelId the channel id
 * @return the actors registered for the channel, or an empty set if the channel is unknown
 */
public Set<BaseActorHandle> getActorsByChannelId(String channelId) {
// The default set is created eagerly on every call, whether or not it ends up being used.
return channelGroupedActors.getOrDefault(channelId, Sets.newHashSet());
}
throw new RuntimeException("Vertex " + vertexId + " does not exist!");
}
/**
@@ -333,27 +202,4 @@ public class ExecutionGraph implements Serializable {
.collect(Collectors.toList());
}
/**
 * Collects the actor names of all execution vertices whose actor id is in {@code actorIds}.
 *
 * @param actorIds actor ids to resolve
 * @return the names of the matching vertices; empty if none match
 */
public Set<String> getActorName(Set<ActorId> actorIds) {
  final Set<String> names = new HashSet<>();
  for (ExecutionVertex vertex : getAllExecutionVertices()) {
    if (actorIds.contains(vertex.getActorId())) {
      names.add(vertex.getActorName());
    }
  }
  return names;
}
/**
 * Resolves the actor name for a single actor id.
 *
 * @param actorId actor id to resolve
 * @return the actor name, or {@code null} if no execution vertex has this actor id
 */
public String getActorName(ActorId actorId) {
  final Set<ActorId> singleId = Sets.newHashSet();
  singleId.add(actorId);

  final Set<String> names = getActorName(singleId);
  return names.isEmpty() ? null : names.iterator().next();
}
/**
 * Returns the ids of all actors, in the order given by {@code getAllActors()}.
 */
public List<ActorId> getAllActorsId() {
  final List<ActorId> ids = new ArrayList<>();
  for (BaseActorHandle actor : getAllActors()) {
    ids.add(actor.getId());
  }
  return ids;
}
}
@@ -3,12 +3,11 @@ package io.ray.streaming.runtime.core.graph.executiongraph;
import com.google.common.base.MoreObjects;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.jobgraph.JobEdge;
import java.io.Serializable;
/**
* An edge that connects two execution job vertices.
*/
public class ExecutionJobEdge implements Serializable {
public class ExecutionJobEdge {
/**
* The source(upstream) execution job vertex.
@@ -8,7 +8,6 @@ import io.ray.streaming.jobgraph.JobVertex;
import io.ray.streaming.jobgraph.VertexType;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.runtime.config.master.ResourceConfig;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@@ -21,7 +20,7 @@ import org.aeonbits.owner.ConfigFactory;
* <p>Execution job vertex is the physical form of {@link JobVertex} and
* every execution job vertex is corresponding to a group of {@link ExecutionVertex}.
*/
public class ExecutionJobVertex implements Serializable {
public class ExecutionJobVertex {
/**
* Unique id. Use {@link JobVertex}'s id directly.
@@ -9,7 +9,6 @@ import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.runtime.config.master.ResourceConfig;
import io.ray.streaming.runtime.core.resource.ContainerId;
import io.ray.streaming.runtime.core.resource.ResourceType;
import io.ray.streaming.runtime.transfer.channel.ChannelId;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
@@ -61,8 +60,6 @@ public class ExecutionVertex implements Serializable {
*/
private ContainerId containerId;
private String pid;
/**
* Worker actor handle.
*/
@@ -76,14 +73,6 @@ public class ExecutionVertex implements Serializable {
private List<ExecutionEdge> inputEdges = new ArrayList<>();
private List<ExecutionEdge> outputEdges = new ArrayList<>();
private transient List<String> outputChannelIdList;
private transient List<String> inputChannelIdList;
private transient List<BaseActorHandle> outputActorList;
private transient List<BaseActorHandle> inputActorList;
private Map<Integer, String> exeVertexChannelMap;
public ExecutionVertex(
int globalIndex,
int index,
@@ -103,7 +92,9 @@ public class ExecutionVertex implements Serializable {
}
private Map<String, String> genWorkerConfig(Map<String, String> jobConfig) {
return new HashMap<>(jobConfig);
Map<String, String> workerConfig = new HashMap<>();
workerConfig.putAll(jobConfig);
return workerConfig;
}
public int getExecutionVertexId() {
@@ -170,14 +161,14 @@ public class ExecutionVertex implements Serializable {
return workerActor;
}
public void setWorkerActor(BaseActorHandle workerActor) {
this.workerActor = workerActor;
}
public ActorId getWorkerActorId() {
return workerActor.getId();
}
public void setWorkerActor(BaseActorHandle workerActor) {
this.workerActor = workerActor;
}
public List<ExecutionEdge> getInputEdges() {
return inputEdges;
}
@@ -208,14 +199,6 @@ public class ExecutionVertex implements Serializable {
.collect(Collectors.toList());
}
/**
 * Returns the id of this vertex's worker actor, or {@code null} while no worker actor is set.
 */
public ActorId getActorId() {
  if (workerActor == null) {
    return null;
  }
  return workerActor.getId();
}
/**
 * Returns the actor name of this vertex, which is its execution vertex id rendered as a string.
 */
public String getActorName() {
return String.valueOf(executionVertexId);
}
public Map<String, Double> getResource() {
return resource;
}
@@ -236,89 +219,12 @@ public class ExecutionVertex implements Serializable {
this.containerId = containerId;
}
/**
 * Returns the pid string last stored with {@code setPid}, or {@code null} if none was stored.
 */
public String getPid() {
return pid;
}
/**
 * Stores the pid string for this vertex.
 *
 * @param pid the pid to remember (presumably the worker's process id - confirm with callers)
 */
public void setPid(String pid) {
this.pid = pid;
}
/**
 * Assigns the container id only when none has been assigned yet; otherwise does nothing.
 *
 * @param containerId container id to assign
 */
public void setContainerIfNotExist(ContainerId containerId) {
  if (this.containerId != null) {
    return;
  }
  this.containerId = containerId;
}
/*---------channel-actor relations---------*/
/**
 * Returns the ids of this vertex's output channels, generating the channel info on first use.
 */
public List<String> getOutputChannelIdList() {
  if (outputChannelIdList != null) {
    return outputChannelIdList;
  }
  generateActorChannelInfo();
  return outputChannelIdList;
}
/**
 * Returns the worker actors on the far side of this vertex's output edges, generating the
 * channel info on first use.
 */
public List<BaseActorHandle> getOutputActorList() {
  if (outputActorList != null) {
    return outputActorList;
  }
  generateActorChannelInfo();
  return outputActorList;
}
/**
 * Returns the ids of this vertex's input channels, generating the channel info on first use.
 */
public List<String> getInputChannelIdList() {
  if (inputChannelIdList != null) {
    return inputChannelIdList;
  }
  generateActorChannelInfo();
  return inputChannelIdList;
}
/**
 * Returns the worker actors on the far side of this vertex's input edges, generating the
 * channel info on first use.
 */
public List<BaseActorHandle> getInputActorList() {
  if (inputActorList != null) {
    return inputActorList;
  }
  generateActorChannelInfo();
  return inputActorList;
}
/**
 * Returns the id of the channel shared with {@code peerVertex}, generating the channel info on
 * first use.
 *
 * @param peerVertex vertex on the other side of the channel
 * @return the channel id, or {@code null} if no channel to that vertex is recorded
 */
public String getChannelIdByPeerVertex(ExecutionVertex peerVertex) {
  if (null == exeVertexChannelMap) {
    generateActorChannelInfo();
  }
  final int peerId = peerVertex.getExecutionVertexId();
  return exeVertexChannelMap.get(peerId);
}
/**
 * Rebuilds the channel-id lists, peer-actor lists and the peer-vertex-to-channel map from the
 * current input and output edges.
 */
private void generateActorChannelInfo() {
  inputChannelIdList = new ArrayList<>();
  inputActorList = new ArrayList<>();
  outputChannelIdList = new ArrayList<>();
  outputActorList = new ArrayList<>();
  exeVertexChannelMap = new HashMap<>();

  // Input side: the channel id runs from the upstream vertex to this vertex.
  for (ExecutionEdge incoming : getInputEdges()) {
    final ExecutionVertex upstream = incoming.getSourceExecutionVertex();
    final String channelId = ChannelId.genIdStr(
        upstream.getExecutionVertexId(),
        getExecutionVertexId(),
        getBuildTime());
    inputChannelIdList.add(channelId);
    inputActorList.add(upstream.getWorkerActor());
    exeVertexChannelMap.put(upstream.getExecutionVertexId(), channelId);
  }

  // Output side: the channel id runs from this vertex to the downstream vertex.
  for (ExecutionEdge outgoing : getOutputEdges()) {
    final ExecutionVertex downstream = outgoing.getTargetExecutionVertex();
    final String channelId = ChannelId.genIdStr(
        getExecutionVertexId(),
        downstream.getExecutionVertexId(),
        getBuildTime());
    outputChannelIdList.add(channelId);
    outputActorList.add(downstream.getWorkerActor());
    exeVertexChannelMap.put(downstream.getExecutionVertexId(), channelId);
  }
}
private Map<String, Double> generateResources(ResourceConfig resourceConfig) {
Map<String, Double> resourceMap = new HashMap<>();
if (resourceConfig.isTaskCpuResourceLimit()) {
@@ -15,7 +15,7 @@ public class ProcessBuilder {
public static StreamProcessor buildProcessor(StreamOperator streamOperator) {
OperatorType type = streamOperator.getOpType();
LOGGER.info("Building StreamProcessor, operator type = {}, operator = {}.", type,
streamOperator.getClass().getSimpleName());
streamOperator.getClass().getSimpleName().toString());
switch (type) {
case SOURCE:
return new SourceProcessor<>((SourceOperator) streamOperator);
@@ -2,7 +2,6 @@ package io.ray.streaming.runtime.core.processor;
import io.ray.streaming.api.collector.Collector;
import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.api.function.Function;
import java.io.Serializable;
import java.util.List;
@@ -12,15 +11,5 @@ public interface Processor<T> extends Serializable {
void process(T t);
/**
* See {@link Function#saveCheckpoint()}.
*/
Serializable saveCheckpoint();
/**
* See {@link Function#loadCheckpoint(Serializable)}.
*/
void loadCheckpoint(Serializable checkpointObject);
void close();
}
@@ -19,8 +19,8 @@ public class SourceProcessor<T> extends StreamProcessor<Record, SourceOperator<T
throw new UnsupportedOperationException("SourceProcessor should not process record");
}
public void fetch() {
operator.fetch();
public void run() {
operator.run();
}
@Override
@@ -3,7 +3,6 @@ package io.ray.streaming.runtime.core.processor;
import io.ray.streaming.api.collector.Collector;
import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.operator.Operator;
import java.io.Serializable;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -36,16 +35,6 @@ public abstract class StreamProcessor<T, P extends Operator> implements Processo
LOGGER.info("opened {}", this);
}
/**
 * Delegates to the wrapped operator and returns whatever checkpoint object it produces.
 */
@Override
public Serializable saveCheckpoint() {
return operator.saveCheckpoint();
}
/**
 * Hands the given checkpoint object to the wrapped operator.
 *
 * @param checkpointObject checkpoint to pass to the operator
 */
@Override
public void loadCheckpoint(Serializable checkpointObject) {
operator.loadCheckpoint(checkpointObject);
}
@Override
public String toString() {
return this.getClass().getSimpleName();
@@ -1,36 +1,18 @@
package io.ray.streaming.runtime.master;
import com.google.common.base.Preconditions;
import com.google.protobuf.InvalidProtocolBufferException;
import io.ray.api.ActorHandle;
import io.ray.api.BaseActorHandle;
import io.ray.api.Ray;
import io.ray.api.id.ActorId;
import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.runtime.config.StreamingConfig;
import io.ray.streaming.runtime.config.StreamingMasterConfig;
import io.ray.streaming.runtime.context.ContextBackend;
import io.ray.streaming.runtime.context.ContextBackendFactory;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
import io.ray.streaming.runtime.core.resource.Container;
import io.ray.streaming.runtime.generated.RemoteCall;
import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext;
import io.ray.streaming.runtime.master.coordinator.CheckpointCoordinator;
import io.ray.streaming.runtime.master.coordinator.FailoverCoordinator;
import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport;
import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest;
import io.ray.streaming.runtime.master.graphmanager.GraphManager;
import io.ray.streaming.runtime.master.graphmanager.GraphManagerImpl;
import io.ray.streaming.runtime.master.resourcemanager.ResourceManager;
import io.ray.streaming.runtime.master.resourcemanager.ResourceManagerImpl;
import io.ray.streaming.runtime.master.scheduler.JobSchedulerImpl;
import io.ray.streaming.runtime.util.CheckpointStateUtil;
import io.ray.streaming.runtime.util.ResourceUtil;
import io.ray.streaming.runtime.util.Serializer;
import io.ray.streaming.runtime.worker.JobWorker;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -42,68 +24,33 @@ public class JobMaster {
private static final Logger LOG = LoggerFactory.getLogger(JobMaster.class);
private JobMasterRuntimeContext runtimeContext;
private JobRuntimeContext runtimeContext;
private ResourceManager resourceManager;
private JobSchedulerImpl scheduler;
private GraphManager graphManager;
private StreamingMasterConfig conf;
private ContextBackend contextBackend;
private ActorHandle<JobMaster> jobMasterActor;
// coordinators
private CheckpointCoordinator checkpointCoordinator;
private FailoverCoordinator failoverCoordinator;
public JobMaster(Map<String, String> confMap) {
LOG.info("Creating job master with conf: {}.", confMap);
StreamingConfig streamingConfig = new StreamingConfig(confMap);
this.conf = streamingConfig.masterConfig;
this.contextBackend = ContextBackendFactory.getContextBackend(this.conf);
// init runtime context
runtimeContext = new JobMasterRuntimeContext(streamingConfig);
// load checkpoint if is recover
if (Ray.getRuntimeContext().wasCurrentActorRestarted()) {
loadMasterCheckpoint();
}
runtimeContext = new JobRuntimeContext(streamingConfig);
LOG.info("Finished creating job master.");
}
/**
 * Builds the storage key of JobMaster's runtime context: the configured JobMaster checkpoint
 * prefix followed by the job name.
 *
 * @param conf master config providing the prefix and the job name
 * @return the storage key
 */
public static String getJobMasterRuntimeContextKey(StreamingMasterConfig conf) {
  final String prefix = conf.checkpointConfig.jobMasterContextCpPrefixKey();
  return prefix + conf.commonConfig.jobName();
}
/**
 * Restores JobMaster's runtime context from the context backend. If nothing is stored, only
 * checkpoint id 0 is recorded and the current context is kept.
 */
private void loadMasterCheckpoint() {
  LOG.info("Start to load JobMaster's checkpoint.");

  final String contextKey = getJobMasterRuntimeContextKey(getConf());
  final byte[] savedContext = CheckpointStateUtil.get(contextBackend, contextKey);

  if (savedContext != null) {
    // Recover the runtime context from the stored bytes.
    this.runtimeContext = Serializer.decode(savedContext);
    // FO case, triggered by ray, we need to register context when loading checkpoint
    LOG.info("JobMaster recover runtime context[{}] from state backend.", runtimeContext);
    init(true);
    return;
  }

  LOG.warn("JobMaster got empty checkpoint from state backend. Skip loading checkpoint.");
  // cp 0 was automatically saved when job started, see StreamTask.
  runtimeContext.checkpointIds.add(0L);
}
/**
* Init JobMaster. To initiate or recover other components(like metrics and extra coordinators).
*
* @return init result
*/
public Boolean init(boolean isRecover) {
LOG.info("Initializing job master, isRecover={}.", isRecover);
public Boolean init() {
LOG.info("Initializing job master.");
if (this.runtimeContext.getExecutionGraph() == null) {
LOG.error("Init job master failed. Job graphs is null.");
@@ -113,14 +60,6 @@ public class JobMaster {
ExecutionGraph executionGraph = graphManager.getExecutionGraph();
Preconditions.checkArgument(executionGraph != null, "no execution graph");
// init coordinators
checkpointCoordinator = new CheckpointCoordinator(this);
checkpointCoordinator.start();
failoverCoordinator = new FailoverCoordinator(this, isRecover);
failoverCoordinator.start();
saveContext();
LOG.info("Finished initializing job master.");
return true;
}
@@ -162,86 +101,11 @@ public class JobMaster {
return true;
}
/**
 * Serializes the runtime context and stores it in the context backend under JobMaster's
 * context key. Does nothing when the runtime context or the config is missing.
 */
public synchronized void saveContext() {
  if (runtimeContext == null || getConf() == null) {
    return;
  }
  LOG.debug("Save JobMaster context.");
  final byte[] encodedContext = Serializer.encode(runtimeContext);
  final String contextKey = getJobMasterRuntimeContextKey(getConf());
  CheckpointStateUtil.put(contextBackend, contextKey, encodedContext);
}
/**
 * Handles a worker's commit report and forwards it to the checkpoint coordinator.
 *
 * @param reportBytes serialized {@code RemoteCall.BaseWorkerCmd} whose detail carries a
 *     {@code RemoteCall.WorkerCommitReport}
 * @return serialized {@code RemoteCall.BoolResult}; {@code false} when the payload could not
 *     be parsed, otherwise the coordinator's answer
 */
public byte[] reportJobWorkerCommit(byte[] reportBytes) {
  Boolean accepted = false;
  try {
    final RemoteCall.BaseWorkerCmd cmd = RemoteCall.BaseWorkerCmd.parseFrom(reportBytes);
    final ActorId workerId = ActorId.fromBytes(cmd.getActorId().toByteArray());
    final long elapsedMs = System.currentTimeMillis() - cmd.getTimestamp();
    LOG.info("Vertex {}, request job worker commit cost {}ms, actorId={}.",
        getExecutionVertex(workerId), elapsedMs, workerId);

    final RemoteCall.WorkerCommitReport detail =
        cmd.getDetail().unpack(RemoteCall.WorkerCommitReport.class);
    accepted = checkpointCoordinator.reportJobWorkerCommit(
        new WorkerCommitReport(workerId, detail.getCommitCheckpointId()));
  } catch (InvalidProtocolBufferException e) {
    LOG.error("Parse job worker commit has exception.", e);
  }
  return RemoteCall.BoolResult.newBuilder().setBoolRes(accepted).build().toByteArray();
}
/**
 * Handles a worker's rollback request and hands it to the failover coordinator.
 *
 * <p>Requests from actors that are unknown to the execution graph are rejected. For known
 * actors the vertex pid is refreshed from the request before the rollback is queued.
 *
 * @param requestBytes serialized {@code RemoteCall.BaseWorkerCmd} whose detail carries a
 *     {@code RemoteCall.WorkerRollbackRequest}
 * @return serialized {@code RemoteCall.BoolResult}; {@code false} for unknown actors or when
 *     any step failed, otherwise the coordinator's answer
 */
public byte[] requestJobWorkerRollback(byte[] requestBytes) {
  Boolean accepted = false;
  try {
    final RemoteCall.BaseWorkerCmd cmd = RemoteCall.BaseWorkerCmd.parseFrom(requestBytes);
    final ActorId workerId = ActorId.fromBytes(cmd.getActorId().toByteArray());
    final long elapsedMs = System.currentTimeMillis() - cmd.getTimestamp();

    final Optional<BaseActorHandle> worker =
        graphManager.getExecutionGraph().getActorById(workerId);
    if (!worker.isPresent()) {
      LOG.warn("Skip this invalid rollback, actor id {} is not found.", workerId);
      return RemoteCall.BoolResult.newBuilder().setBoolRes(false).build().toByteArray();
    }

    final ExecutionVertex vertex = getExecutionVertex(workerId);
    LOG.info("Vertex {}, request job worker rollback cost {}ms, actorId={}.",
        vertex, elapsedMs, workerId);

    final RemoteCall.WorkerRollbackRequest detail =
        RemoteCall.WorkerRollbackRequest.parseFrom(cmd.getDetail().getValue());
    vertex.setPid(detail.getWorkerPid());

    // Look up the container that hosted the vertex; fall back to an empty hostname.
    final Optional<Container> hostContainer = ResourceUtil.getContainerById(
        resourceManager.getRegisteredContainers(), vertex.getContainerId());
    final String hostname = hostContainer.isPresent() ? hostContainer.get().getHostname() : "";

    accepted = failoverCoordinator.requestJobWorkerRollback(
        new WorkerRollbackRequest(workerId, detail.getExceptionMsg(), hostname, vertex.getPid()));
    LOG.info("Vertex {} request rollback, exception msg : {}.",
        vertex, detail.getExceptionMsg());
  } catch (Throwable e) {
    LOG.error("Parse job worker rollback has exception.", e);
  }
  return RemoteCall.BoolResult.newBuilder().setBoolRes(accepted).build().toByteArray();
}
/**
 * Looks up the execution vertex bound to the given actor id in the current execution graph.
 *
 * @param id actor id to resolve
 * @return whatever the execution graph reports for that actor id
 */
private ExecutionVertex getExecutionVertex(ActorId id) {
  final ExecutionGraph currentGraph = graphManager.getExecutionGraph();
  return currentGraph.getExecutionVertexByActorId(id);
}
/**
 * Returns the actor handle held in this JobMaster's {@code jobMasterActor} field.
 *
 * @return the stored handle, as-is
 */
public ActorHandle<JobMaster> getJobMasterActor() {
return jobMasterActor;
}
public JobMasterRuntimeContext getRuntimeContext() {
public JobRuntimeContext getRuntimeContext() {
return runtimeContext;
}
@@ -1,81 +0,0 @@
package io.ray.streaming.runtime.master.context;
import com.google.common.base.MoreObjects;
import com.google.common.collect.Sets;
import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.runtime.config.StreamingConfig;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph;
import io.ray.streaming.runtime.master.coordinator.command.BaseWorkerCmd;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
/**
* Runtime context for job master, which will be stored in backend when saving checkpoint.
*
* <p>Including: graph, resource, checkpoint info, etc.
*/
public class JobMasterRuntimeContext implements Serializable {
/*--------------Checkpoint----------------*/
public volatile List<Long> checkpointIds = new ArrayList<>();
public volatile long lastCheckpointId = 0;
public volatile long lastCpTimestamp = 0;
public volatile BlockingQueue<BaseWorkerCmd> cpCmds = new LinkedBlockingQueue<>();
/*--------------Failover----------------*/
public volatile BlockingQueue<BaseWorkerCmd> foCmds = new ArrayBlockingQueue<>(8192);
public volatile Set<BaseWorkerCmd> unfinishedFoCmds = Sets.newConcurrentHashSet();
private StreamingConfig conf;
private JobGraph jobGraph;
private volatile ExecutionGraph executionGraph;
public JobMasterRuntimeContext(StreamingConfig conf) {
this.conf = conf;
}
public String getJobName() {
return conf.masterConfig.commonConfig.jobName();
}
public StreamingConfig getConf() {
return conf;
}
public JobGraph getJobGraph() {
return jobGraph;
}
public void setJobGraph(JobGraph jobGraph) {
this.jobGraph = jobGraph;
}
public ExecutionGraph getExecutionGraph() {
return executionGraph;
}
public void setExecutionGraph(ExecutionGraph executionGraph) {
this.executionGraph = executionGraph;
}
public Long getLastValidCheckpointId() {
if (checkpointIds.isEmpty()) {
// OL is invalid checkpoint id, worker will pass it
return 0L;
}
return checkpointIds.get(checkpointIds.size() - 1);
}
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("jobGraph", jobGraph)
.add("executionGraph", executionGraph)
.add("conf", conf.getMap())
.toString();
}
}
@@ -1,44 +0,0 @@
package io.ray.streaming.runtime.master.coordinator;
import io.ray.api.Ray;
import io.ray.streaming.runtime.master.JobMaster;
import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext;
import io.ray.streaming.runtime.master.graphmanager.GraphManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Base class for the job master's background coordinators.
 *
 * <p>A coordinator is a {@link Runnable} that {@link #start()} runs on its own thread (wrapped
 * with {@code Ray.wrapRunnable}). Subclasses are expected to poll {@link #closed} in their
 * {@code run()} loop so that {@link #stop()} can end them.
 */
public abstract class BaseCoordinator implements Runnable {

  private static final Logger LOG = LoggerFactory.getLogger(BaseCoordinator.class);

  protected final JobMaster jobMaster;
  protected final JobMasterRuntimeContext runtimeContext;
  protected final GraphManager graphManager;

  /** Set by {@link #stop()}; tells the coordinator loop to exit. */
  protected volatile boolean closed;

  private Thread thread;

  /**
   * Captures the job master together with its runtime context and graph manager.
   *
   * @param jobMaster the owning job master
   */
  public BaseCoordinator(JobMaster jobMaster) {
    this.jobMaster = jobMaster;
    this.runtimeContext = jobMaster.getRuntimeContext();
    this.graphManager = jobMaster.getGraphManager();
  }

  /** Starts the coordinator on a new thread named after the concrete class and start time. */
  public void start() {
    thread = new Thread(Ray.wrapRunnable(this),
        this.getClass().getName() + "-" + System.currentTimeMillis());
    thread.start();
  }

  /**
   * Signals the coordinator to stop and waits up to 30 seconds for its thread to finish.
   *
   * <p>If the calling thread is interrupted while waiting, the interruption is logged and the
   * caller's interrupt status is restored so that code further up the stack can react to it.
   */
  public void stop() {
    closed = true;
    try {
      if (thread != null) {
        thread.join(30000);
      }
    } catch (InterruptedException e) {
      LOG.error("Coordinator thread exit has exception.", e);
      // join() cleared the interrupt flag when it threw; put it back instead of swallowing it.
      Thread.currentThread().interrupt();
    }
  }
}
@@ -1,215 +0,0 @@
package io.ray.streaming.runtime.master.coordinator;
import com.google.common.base.Preconditions;
import io.ray.api.BaseActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.id.ActorId;
import io.ray.runtime.exception.RayException;
import io.ray.streaming.runtime.master.JobMaster;
import io.ray.streaming.runtime.master.coordinator.command.BaseWorkerCmd;
import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport;
import io.ray.streaming.runtime.rpc.RemoteCallWorker;
import io.ray.streaming.runtime.worker.JobWorker;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * CheckpointCoordinator is the controller of checkpoint, responsible for triggering checkpoint,
 * collecting {@link JobWorker}'s reports and calling {@link JobWorker} to clear expired
 * checkpoints when new checkpoint finished.
 *
 * <p>Commands reach the coordinator through the runtime context's {@code cpCmds} queue and are
 * processed one at a time inside {@link #run()}.
 */
public class CheckpointCoordinator extends BaseCoordinator {

  private static final Logger LOG = LoggerFactory.getLogger(CheckpointCoordinator.class);

  /** Actors whose commit report for the in-flight checkpoint is still outstanding. */
  private final Set<ActorId> pendingCheckpointActors = new HashSet<>();

  /** Checkpoint ids that were already interrupted, so they are not interrupted twice. */
  private final Set<Long> interruptedCheckpointSet = new HashSet<>();

  private final int cpIntervalSecs;
  private final int cpTimeoutSecs;

  /**
   * Reads the checkpoint interval and timeout from the master configuration.
   *
   * @param jobMaster the owning job master
   */
  public CheckpointCoordinator(JobMaster jobMaster) {
    super(jobMaster);

    // get checkpoint interval from conf
    this.cpIntervalSecs = runtimeContext.getConf().masterConfig.checkpointConfig.cpIntervalSecs();
    this.cpTimeoutSecs = runtimeContext.getConf().masterConfig.checkpointConfig.cpTimeoutSecs();

    // Trigger next checkpoint in interval by reset last checkpoint timestamp.
    runtimeContext.lastCpTimestamp = System.currentTimeMillis();
  }

  /**
   * Coordinator loop: handles at most one queued command per iteration, then either checks the
   * in-flight checkpoint for timeout or triggers a new checkpoint when the interval elapsed.
   * Any failure interrupts the current checkpoint and the loop keeps running until closed.
   */
  @Override
  public void run() {
    while (!closed) {
      try {
        final BaseWorkerCmd command = runtimeContext.cpCmds.poll(1, TimeUnit.SECONDS);
        if (command != null) {
          if (command instanceof WorkerCommitReport) {
            processCommitReport((WorkerCommitReport) command);
          } else {
            // every other command type is treated as a request to interrupt the checkpoint
            interruptCheckpoint();
          }
        }

        if (!pendingCheckpointActors.isEmpty()) {
          // if wait commit report timeout, this cp fail, and restart next cp
          if (timeoutOnWaitCheckpoint()) {
            LOG.warn("Waiting for checkpoint {} timeout, pending cp actors is {}.",
                runtimeContext.lastCheckpointId,
                graphManager.getExecutionGraph().getActorName(pendingCheckpointActors));
            interruptCheckpoint();
          }
        } else {
          maybeTriggerCheckpoint();
        }
      } catch (Throwable e) {
        LOG.error("Checkpoint coordinator occur err.", e);
        try {
          interruptCheckpoint();
        } catch (Throwable interruptE) {
          // Deliberately not rethrown, but keep the cause visible in the log.
          LOG.error("Ignore interrupt checkpoint exception in catch block.", interruptE);
        }
      }
    }
    LOG.warn("Checkpoint coordinator thread exit.");
  }

  /**
   * Queues a worker's commit report for processing by the coordinator loop.
   *
   * @param report the commit report
   * @return {@code true} if the report was queued, {@code false} if the queue rejected it
   */
  public Boolean reportJobWorkerCommit(WorkerCommitReport report) {
    LOG.info("Report job worker commit {}.", report);

    Boolean ret = runtimeContext.cpCmds.offer(report);
    if (!ret) {
      LOG.warn("Report job worker commit failed, because command queue is full.");
    }
    return ret;
  }

  /**
   * Applies one commit report: removes the reporting actor from the pending set and, once the
   * set is empty, records the checkpoint as finished, clears expired checkpoints on the workers
   * and saves the master context. Failures are logged and not propagated.
   */
  private void processCommitReport(WorkerCommitReport commitReport) {
    LOG.info("Start process commit report {}, from actor name={}.", commitReport,
        graphManager.getExecutionGraph().getActorName(commitReport.fromActorId));

    try {
      Preconditions.checkArgument(
          commitReport.commitCheckpointId == runtimeContext.lastCheckpointId,
          "expect checkpointId %s, but got %s",
          runtimeContext.lastCheckpointId, commitReport);

      if (!pendingCheckpointActors.contains(commitReport.fromActorId)) {
        LOG.warn("Invalid commit report, skipped.");
        return;
      }

      pendingCheckpointActors.remove(commitReport.fromActorId);
      LOG.info("Pending actors after this commit: {}.",
          graphManager.getExecutionGraph().getActorName(pendingCheckpointActors));

      // checkpoint finish
      if (pendingCheckpointActors.isEmpty()) {
        // actor finish
        runtimeContext.checkpointIds.add(runtimeContext.lastCheckpointId);

        if (clearExpiredCpStateAndQueueMsg()) {
          // save master context
          jobMaster.saveContext();

          LOG.info("Finish checkpoint: {}.", runtimeContext.lastCheckpointId);
        } else {
          LOG.warn("Fail to do checkpoint: {}.", runtimeContext.lastCheckpointId);
        }
      }

      LOG.info("Process commit report {} success.", commitReport);
    } catch (Throwable e) {
      LOG.warn("Process commit report has exception.", e);
    }
  }

  /**
   * Starts the next checkpoint: marks every actor as pending, increments the checkpoint id and
   * asks all source actors to trigger it. A {@link RayException} returned by a source is
   * rethrown.
   */
  private void triggerCheckpoint() {
    interruptedCheckpointSet.clear();
    if (LOG.isInfoEnabled()) {
      LOG.info("Start trigger checkpoint {}.", runtimeContext.lastCheckpointId + 1);
    }

    List<ActorId> allIds = graphManager.getExecutionGraph().getAllActorsId();
    // do the checkpoint
    pendingCheckpointActors.addAll(allIds);

    // inc last checkpoint id
    ++runtimeContext.lastCheckpointId;

    final List<ObjectRef> sourcesRet = new ArrayList<>();

    graphManager.getExecutionGraph().getSourceActors().forEach(actor -> {
      sourcesRet.add(RemoteCallWorker.triggerCheckpoint(
          actor, runtimeContext.lastCheckpointId));
    });

    for (ObjectRef rayObject : sourcesRet) {
      // Resolve each reference once and reuse the value for the check, the log and the throw.
      final Object result = rayObject.get();
      if (result instanceof RayException) {
        final RayException failure = (RayException) result;
        LOG.warn("Trigger checkpoint has exception.", failure);
        throw failure;
      }
    }
    runtimeContext.lastCpTimestamp = System.currentTimeMillis();
    LOG.info("Trigger checkpoint success.");
  }

  /**
   * Abandons the in-flight checkpoint (at most once per checkpoint id): notifies all actors of
   * the timeout if the checkpoint is newer than the last valid one, clears the pending set and
   * triggers a new checkpoint if one is due.
   */
  private void interruptCheckpoint() {
    // notify checkpoint timeout is time-consuming while many workers crash or
    // container failover.
    if (interruptedCheckpointSet.contains(runtimeContext.lastCheckpointId)) {
      LOG.warn("Skip interrupt duplicated checkpoint id : {}.", runtimeContext.lastCheckpointId);
      return;
    }
    interruptedCheckpointSet.add(runtimeContext.lastCheckpointId);
    LOG.warn("Interrupt checkpoint, checkpoint id : {}.", runtimeContext.lastCheckpointId);

    List<BaseActorHandle> allActor = graphManager.getExecutionGraph().getAllActors();
    if (runtimeContext.lastCheckpointId > runtimeContext.getLastValidCheckpointId()) {
      RemoteCallWorker
          .notifyCheckpointTimeoutParallel(allActor, runtimeContext.lastCheckpointId);
    }

    if (!pendingCheckpointActors.isEmpty()) {
      pendingCheckpointActors.clear();
    }
    maybeTriggerCheckpoint();
  }

  /** Triggers a checkpoint if the configured interval has elapsed. */
  private void maybeTriggerCheckpoint() {
    if (readyToTrigger()) {
      triggerCheckpoint();
    }
  }

  /**
   * Asks all actors to drop expired checkpoint data. With a single recorded checkpoint only
   * queue messages are cleared; with more than one, the oldest recorded id is removed and its
   * state is cleared as well.
   *
   * @return always {@code true}
   */
  private boolean clearExpiredCpStateAndQueueMsg() {
    // queue msg must clear when first checkpoint finish
    List<BaseActorHandle> allActor = graphManager.getExecutionGraph().getAllActors();
    if (1 == runtimeContext.checkpointIds.size()) {
      Long msgExpiredCheckpointId = runtimeContext.checkpointIds.get(0);
      RemoteCallWorker.clearExpiredCheckpointParallel(allActor, 0L, msgExpiredCheckpointId);
    }

    if (runtimeContext.checkpointIds.size() > 1) {
      Long stateExpiredCpId = runtimeContext.checkpointIds.remove(0);
      Long msgExpiredCheckpointId = runtimeContext.checkpointIds.get(0);
      RemoteCallWorker
          .clearExpiredCheckpointParallel(allActor, stateExpiredCpId, msgExpiredCheckpointId);
    }
    return true;
  }

  /** Returns whether at least the checkpoint interval has passed since the last timestamp. */
  private boolean readyToTrigger() {
    // 1000L forces long arithmetic; an int product would overflow for large intervals.
    return (System.currentTimeMillis() - runtimeContext.lastCpTimestamp)
        >= cpIntervalSecs * 1000L;
  }

  /** Returns whether at least the checkpoint timeout has passed since the last timestamp. */
  private boolean timeoutOnWaitCheckpoint() {
    // 1000L forces long arithmetic; an int product would overflow for large timeouts.
    return (System.currentTimeMillis() - runtimeContext.lastCpTimestamp)
        >= cpTimeoutSecs * 1000L;
  }
}
@@ -1,281 +0,0 @@
package io.ray.streaming.runtime.master.coordinator;
import io.ray.api.BaseActorHandle;
import io.ray.api.id.ActorId;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
import io.ray.streaming.runtime.core.resource.Container;
import io.ray.streaming.runtime.master.JobMaster;
import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext;
import io.ray.streaming.runtime.master.coordinator.command.BaseWorkerCmd;
import io.ray.streaming.runtime.master.coordinator.command.InterruptCheckpointRequest;
import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest;
import io.ray.streaming.runtime.rpc.async.AsyncRemoteCaller;
import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo;
import io.ray.streaming.runtime.util.ResourceUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import org.apache.commons.collections.map.DefaultedMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Coordinator that processes worker rollback requests.
 *
 * <p>Requests are taken from the runtime context's {@code foCmds} queue. Each rollback is
 * executed asynchronously on the worker; when a worker reports lost input queues, rollback
 * requests for the corresponding upstream vertices are queued as part of the same cascading
 * group.
 */
public class FailoverCoordinator extends BaseCoordinator {

  private static final Logger LOG = LoggerFactory.getLogger(FailoverCoordinator.class);

  /** Delay before a failed rollback is queued again. */
  private static final int ROLLBACK_RETRY_TIME_MS = 10 * 1000;

  /** Guards taking from / adding to the failover command queue; see rollback(). */
  private final Object cmdLock = new Object();
  private final AsyncRemoteCaller asyncRemoteCaller;

  /** Next id handed to a rollback request that has no cascading group yet. */
  private long currentCascadingGroupId = 0;

  /** Vertices with a rollback in flight; looking up an unknown vertex yields {@code false}. */
  private final Map<ExecutionVertex, Boolean> isRollbacking =
      DefaultedMap.decorate(new ConcurrentHashMap<ExecutionVertex, Boolean>(), false);

  public FailoverCoordinator(JobMaster jobMaster, boolean isRecover) {
    this(jobMaster, new AsyncRemoteCaller(), isRecover);
  }

  /**
   * @param jobMaster the owning job master
   * @param asyncRemoteCaller caller used for asynchronous worker RPCs
   * @param isRecover when {@code true}, commands left unfinished in the runtime context are
   *     queued again before the unfinished set is cleared
   */
  public FailoverCoordinator(
      JobMaster jobMaster, AsyncRemoteCaller asyncRemoteCaller,
      boolean isRecover) {
    super(jobMaster);

    this.asyncRemoteCaller = asyncRemoteCaller;
    // recover unfinished FO commands
    JobMasterRuntimeContext runtimeContext = jobMaster.getRuntimeContext();
    if (isRecover) {
      runtimeContext.foCmds.addAll(runtimeContext.unfinishedFoCmds);
    }
    runtimeContext.unfinishedFoCmds.clear();
  }

  /**
   * Coordinator loop: polls the failover queue and handles rollback requests until closed.
   * Errors are logged and the loop continues.
   */
  @Override
  public void run() {
    while (!closed) {
      try {
        final BaseWorkerCmd command;
        // see rollback() for lock reason
        synchronized (cmdLock) {
          command = jobMaster.getRuntimeContext().foCmds.poll(1, TimeUnit.SECONDS);
        }
        if (null == command) {
          continue;
        }
        if (command instanceof WorkerRollbackRequest) {
          jobMaster.getRuntimeContext().unfinishedFoCmds.add(command);
          dealWithRollbackRequest((WorkerRollbackRequest) command);
        }
      } catch (Throwable e) {
        LOG.error("Fo coordinator occur err.", e);
      }
    }
    LOG.warn("Fo coordinator thread exit.");
  }

  /** Returns whether a command from the same actor is already waiting in the queue. */
  private Boolean isDuplicateRequest(WorkerRollbackRequest request) {
    try {
      Object[] foCmdsArray = runtimeContext.foCmds.toArray();
      for (Object cmd : foCmdsArray) {
        if (request.fromActorId.equals(((BaseWorkerCmd) cmd).fromActorId)) {
          return true;
        }
      }
    } catch (Exception e) {
      LOG.warn("Check request is duplicated failed.", e);
    }
    return false;
  }

  /**
   * Queues a rollback request unless one from the same actor is already queued, then saves the
   * master context.
   *
   * @param request the rollback request
   * @return {@code true} if the request was queued or skipped as a duplicate, {@code false} if
   *     the queue rejected it
   */
  public Boolean requestJobWorkerRollback(WorkerRollbackRequest request) {
    LOG.info("Request job worker rollback {}.", request);
    boolean ret;
    if (!isDuplicateRequest(request)) {
      ret = runtimeContext.foCmds.offer(request);
    } else {
      LOG.warn("Skip duplicated worker rollback request, {}.", request.toString());
      return true;
    }
    jobMaster.saveContext();
    if (!ret) {
      LOG.warn("Request job worker rollback failed, because command queue is full.");
    }
    return ret;
  }

  /**
   * Handles one rollback request: refreshes the vertex pid, skips vertices already rolling
   * back, and either rolls back directly (forced) or first asks the worker whether a rollback
   * is needed.
   */
  private void dealWithRollbackRequest(WorkerRollbackRequest rollbackRequest) {
    LOG.info("Start deal with rollback request {}.", rollbackRequest);

    ExecutionVertex exeVertex = getExeVertexFromRequest(rollbackRequest);

    // Reset pid for new-rollback actor.
    if (null != rollbackRequest.getPid() &&
        !rollbackRequest.getPid().equals(WorkerRollbackRequest.DEFAULT_PID)) {
      exeVertex.setPid(rollbackRequest.getPid());
    }

    if (isRollbacking.get(exeVertex)) {
      LOG.info("Vertex {} is rollbacking, skip rollback again.", exeVertex);
      return;
    }

    if (rollbackRequest.isForcedRollback) {
      interruptCheckpointAndRollback(rollbackRequest);
    } else {
      asyncRemoteCaller.checkIfNeedRollbackAsync(exeVertex.getWorkerActor(), res -> {
        if (!res) {
          LOG.info("Vertex {} doesn't need to rollback, skip it.", exeVertex);
          return;
        }
        interruptCheckpointAndRollback(rollbackRequest);
      }, throwable -> {
        LOG.error("Exception when calling checkIfNeedRollbackAsync, maybe vertex is dead" +
            ", ignore this request, vertex={}.", exeVertex, throwable);
      });
    }

    LOG.info("Deal with rollback request {} success.", rollbackRequest);
  }

  /**
   * Assigns a cascading group id to the request if it has none, rolls the worker back to the
   * last valid checkpoint and asks the checkpoint coordinator to interrupt the current
   * checkpoint.
   */
  private void interruptCheckpointAndRollback(WorkerRollbackRequest rollbackRequest) {
    // assign a cascadingGroupId
    if (rollbackRequest.cascadingGroupId == null) {
      rollbackRequest.cascadingGroupId = currentCascadingGroupId++;
    }
    // get last valid checkpoint id then call worker rollback.
    // Use the id stored on the request: the counter has already moved on for fresh requests
    // and is unrelated to the group of a cascaded request.
    rollback(jobMaster.getRuntimeContext().getLastValidCheckpointId(), rollbackRequest,
        rollbackRequest.cascadingGroupId);
    // we interrupt current checkpoint for 2 considerations:
    // 1. current checkpoint might be timeout, because barrier might be lost after failover. so we
    // interrupt current checkpoint to avoid waiting.
    // 2. when we want to rollback vertex to n, job finished checkpoint n+1 and cleared state
    // of checkpoint n.
    jobMaster.getRuntimeContext().cpCmds.offer(new InterruptCheckpointRequest());
  }

  /**
   * call worker rollback, and deal with it's reports. callback won't be finished until
   * the entire DAG back to normal.
   *
   * @param checkpointId checkpointId to be rollback
   * @param rollbackRequest worker rollback request
   * @param cascadingGroupId all rollback of a cascading group should have same ID
   */
  private void rollback(
      long checkpointId, WorkerRollbackRequest rollbackRequest,
      long cascadingGroupId) {
    ExecutionVertex exeVertex = getExeVertexFromRequest(rollbackRequest);
    LOG.info("Call vertex {} to rollback, checkpoint id is {}, cascadingGroupId={}.",
        exeVertex, checkpointId, cascadingGroupId);

    isRollbacking.put(exeVertex, true);

    asyncRemoteCaller.rollback(exeVertex.getWorkerActor(), checkpointId, result -> {
      List<WorkerRollbackRequest> newRollbackRequests = new ArrayList<>();
      switch (result.getResultEnum()) {
        case SUCCESS:
          ChannelRecoverInfo recoverInfo = result.getResultObj();
          LOG.info("Vertex {} rollback done, dataLostQueues={}, msg={}, cascadingGroupId={}.",
              exeVertex, recoverInfo.getDataLostQueues(), result.getResultMsg(), cascadingGroupId);
          // rollback upstream if vertex reports abnormal input queues
          newRollbackRequests =
              cascadeUpstreamActors(recoverInfo.getDataLostQueues(), exeVertex, cascadingGroupId);
          break;
        case SKIPPED:
          LOG.info("Vertex skip rollback, result = {}, cascadingGroupId={}.", result,
              cascadingGroupId);
          break;
        default:
          LOG.error(
              "Rollback vertex {} failed, result={}, cascadingGroupId={}," +
                  " rollback this worker again after {} ms.",
              exeVertex, result, cascadingGroupId, ROLLBACK_RETRY_TIME_MS);
          Thread.sleep(ROLLBACK_RETRY_TIME_MS);
          LOG.info("Add rollback request for {} again, cascadingGroupId={}.", exeVertex,
              cascadingGroupId);
          newRollbackRequests.add(
              new WorkerRollbackRequest(exeVertex, "", "Rollback failed, try again.", false)
          );
          break;
      }

      // lock to avoid executing new rollback requests added.
      // consider such a case: A->B->C, C cascade B, and B cascade A
      // if B is rollback before B's rollback request is saved, and then JobMaster crashed,
      // then A will never be rollback.
      synchronized (cmdLock) {
        jobMaster.getRuntimeContext().foCmds.addAll(newRollbackRequests);
        // this rollback request is finished, remove it.
        jobMaster.getRuntimeContext().unfinishedFoCmds.remove(rollbackRequest);
        jobMaster.saveContext();
      }
      isRollbacking.put(exeVertex, false);
    }, throwable -> {
      LOG.error("Exception when calling vertex to rollback, vertex={}.", exeVertex, throwable);
      isRollbacking.put(exeVertex, false);
    });

    LOG.info("Finish rollback vertex {}, checkpoint id is {}.", exeVertex, checkpointId);
  }

  /**
   * Builds forced rollback requests for the upstream vertices behind the given lost queues.
   * Upstream vertices that are already rolling back are skipped.
   *
   * @param dataLostQueues queues for which the rolled-back vertex reported lost data
   * @param fromVertex the vertex that was rolled back
   * @param cascadingGroupId group id copied onto every created request
   * @return the new requests, possibly empty
   */
  private List<WorkerRollbackRequest> cascadeUpstreamActors(
      Set<String> dataLostQueues, ExecutionVertex fromVertex, long cascadingGroupId) {
    List<WorkerRollbackRequest> cascadedRollbackRequest = new ArrayList<>();
    // rollback upstream if vertex reports abnormal input queues
    dataLostQueues.forEach(q -> {
      BaseActorHandle upstreamActor =
          graphManager.getExecutionGraph().getPeerActor(fromVertex.getWorkerActor(), q);
      ExecutionVertex upstreamExeVertex = getExecutionVertex(upstreamActor);
      // vertexes that has already cascaded by other vertex in the same level
      // of graph should be ignored.
      if (isRollbacking.get(upstreamExeVertex)) {
        return;
      }
      LOG.info("Call upstream vertex {} of vertex {} to rollback, cascadingGroupId={}.",
          upstreamExeVertex, fromVertex, cascadingGroupId);
      String hostname = "";
      Optional<Container> container = ResourceUtil.getContainerById(
          jobMaster.getResourceManager().getRegisteredContainers(),
          upstreamExeVertex.getContainerId()
      );
      if (container.isPresent()) {
        hostname = container.get().getHostname();
      }
      // force upstream vertexes to rollback
      WorkerRollbackRequest upstreamRequest = new WorkerRollbackRequest(
          upstreamExeVertex, hostname, String.format("Cascading rollback from %s", fromVertex), true
      );
      upstreamRequest.cascadingGroupId = cascadingGroupId;
      cascadedRollbackRequest.add(upstreamRequest);
    });
    return cascadedRollbackRequest;
  }

  /**
   * Resolves the execution vertex of the requesting actor.
   *
   * @throws RuntimeException if the execution graph does not know the actor id
   */
  private ExecutionVertex getExeVertexFromRequest(WorkerRollbackRequest rollbackRequest) {
    ActorId actorId = rollbackRequest.fromActorId;
    Optional<BaseActorHandle> rayActor = graphManager.getExecutionGraph().getActorById(actorId);
    if (!rayActor.isPresent()) {
      throw new RuntimeException("Can not find ray actor of ID " + actorId);
    }
    return getExecutionVertex(rollbackRequest.fromActorId);
  }

  private ExecutionVertex getExecutionVertex(BaseActorHandle actor) {
    return graphManager.getExecutionGraph().getExecutionVertexByActorId(actor.getId());
  }

  private ExecutionVertex getExecutionVertex(ActorId actorId) {
    return graphManager.getExecutionGraph().getExecutionVertexByActorId(actorId);
  }
}
@@ -1,17 +0,0 @@
package io.ray.streaming.runtime.master.coordinator.command;
import io.ray.api.id.ActorId;
import java.io.Serializable;
/**
 * Base type of the commands placed on the job master's coordinator queues.
 *
 * <p>NOTE(review): presumably {@link Serializable} so queued commands can be saved together
 * with the runtime context that holds the queues; confirm against the serializer in use.
 */
public abstract class BaseWorkerCmd implements Serializable {
// Id of the actor the command originates from; stays null when the no-arg constructor is used.
public ActorId fromActorId;
// Creates a command without an originating actor.
public BaseWorkerCmd() {
}
// Creates a command originating from the given actor.
protected BaseWorkerCmd(ActorId actorId) {
this.fromActorId = actorId;
}
}
@@ -1,5 +0,0 @@
package io.ray.streaming.runtime.master.coordinator.command;
/**
 * Payload-free command asking for the current checkpoint to be interrupted.
 *
 * <p>It declares no fields of its own and uses the inherited no-arg constructor, so
 * {@code fromActorId} is left unset.
 */
public final class InterruptCheckpointRequest extends BaseWorkerCmd {
}
@@ -1,22 +0,0 @@
package io.ray.streaming.runtime.master.coordinator.command;
import com.google.common.base.MoreObjects;
import io.ray.api.id.ActorId;
/**
 * Command carrying a worker's report that it committed a checkpoint.
 */
public final class WorkerCommitReport extends BaseWorkerCmd {
// Id of the checkpoint the reporting worker committed.
public final long commitCheckpointId;
/**
 * @param actorId id of the reporting actor
 * @param commitCheckpointId id of the committed checkpoint
 */
public WorkerCommitReport(ActorId actorId, long commitCheckpointId) {
super(actorId);
this.commitCheckpointId = commitCheckpointId;
}
// Renders the checkpoint id followed by the reporting actor id.
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("commitCheckpointId", commitCheckpointId)
.add("fromActorId", fromActorId)
.toString();
}
}
@@ -1,63 +0,0 @@
package io.ray.streaming.runtime.master.coordinator.command;
import com.google.common.base.MoreObjects;
import io.ray.api.id.ActorId;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
/**
 * Command asking the failover coordinator to roll a worker back.
 */
public final class WorkerRollbackRequest extends BaseWorkerCmd {

  /**
   * Placeholder pid used while the real worker pid is unknown. Declared {@code final} so the
   * sentinel cannot be reassigned at runtime.
   */
  public static final String DEFAULT_PID = "UNKNOWN_PID";

  /** Cascading group this request belongs to; {@code null} until one is assigned. */
  public Long cascadingGroupId = null;

  /** When {@code true} the worker is rolled back without first asking whether it needs to. */
  public boolean isForcedRollback = false;

  private String exceptionMsg = "No detail message.";
  private String hostname = "UNKNOWN_HOST";
  private String pid = DEFAULT_PID;

  /**
   * Creates a request with default message, hostname and pid.
   *
   * @param actorId id of the actor to roll back
   */
  public WorkerRollbackRequest(ActorId actorId) {
    super(actorId);
  }

  /**
   * @param actorId id of the actor to roll back
   * @param msg message describing why the rollback is requested
   */
  public WorkerRollbackRequest(ActorId actorId, String msg) {
    super(actorId);
    exceptionMsg = msg;
  }

  /**
   * Creates a request for the worker of the given vertex; actor id and pid are taken from it.
   *
   * @param executionVertex vertex whose worker should be rolled back
   * @param hostname hostname recorded on the request
   * @param msg message describing why the rollback is requested
   * @param isForcedRollback whether the rollback is forced
   */
  public WorkerRollbackRequest(
      ExecutionVertex executionVertex,
      String hostname,
      String msg,
      boolean isForcedRollback) {

    super(executionVertex.getWorkerActorId());

    this.hostname = hostname;
    this.pid = executionVertex.getPid();
    this.exceptionMsg = msg;
    this.isForcedRollback = isForcedRollback;
  }

  /**
   * Note the parameter order: the message comes before the hostname here.
   *
   * @param actorId id of the actor to roll back
   * @param msg message describing why the rollback is requested
   * @param hostname hostname recorded on the request
   * @param pid pid recorded on the request
   */
  public WorkerRollbackRequest(ActorId actorId, String msg, String hostname, String pid) {
    this(actorId, msg);
    this.hostname = hostname;
    this.pid = pid;
  }

  public String getRollbackExceptionMsg() {
    return exceptionMsg;
  }

  public String getHostname() {
    return hostname;
  }

  public String getPid() {
    return pid;
  }

  /** Renders only the originating actor id. */
  @Override
  public String toString() {
    return MoreObjects.toStringHelper(this)
        .add("fromActorId", fromActorId)
        .toString();
  }
}
@@ -1,19 +1,14 @@
package io.ray.streaming.runtime.master.graphmanager;
import io.ray.api.BaseActorHandle;
import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.jobgraph.JobVertex;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionEdge;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobEdge;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobVertex;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext;
import java.util.HashMap;
import java.util.HashSet;
import io.ray.streaming.runtime.master.JobRuntimeContext;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -21,9 +16,9 @@ public class GraphManagerImpl implements GraphManager {
private static final Logger LOG = LoggerFactory.getLogger(GraphManagerImpl.class);
protected final JobMasterRuntimeContext runtimeContext;
protected final JobRuntimeContext runtimeContext;
public GraphManagerImpl(JobMasterRuntimeContext runtimeContext) {
public GraphManagerImpl(JobRuntimeContext runtimeContext) {
this.runtimeContext = runtimeContext;
}
@@ -53,7 +48,6 @@ public class GraphManagerImpl implements GraphManager {
// create vertex
Map<Integer, ExecutionJobVertex> exeJobVertexMap = new LinkedHashMap<>();
Map<Integer, ExecutionVertex> executionVertexMap = new HashMap<>();
long buildTime = executionGraph.getBuildTime();
for (JobVertex jobVertex : jobGraph.getJobVertices()) {
int jobVertexId = jobVertex.getVertexId();
@@ -65,47 +59,32 @@ public class GraphManagerImpl implements GraphManager {
buildTime));
}
// for each job edge, connect all source exeVertices and target exeVertices
// connect vertex
jobGraph.getJobEdges().forEach(jobEdge -> {
ExecutionJobVertex source = exeJobVertexMap.get(jobEdge.getSrcVertexId());
ExecutionJobVertex target = exeJobVertexMap.get(jobEdge.getTargetVertexId());
ExecutionJobEdge executionJobEdge = new ExecutionJobEdge(source, target, jobEdge);
ExecutionJobEdge executionJobEdge =
new ExecutionJobEdge(source, target, jobEdge);
source.getOutputEdges().add(executionJobEdge);
target.getInputEdges().add(executionJobEdge);
source.getExecutionVertices().forEach(sourceExeVertex -> {
target.getExecutionVertices().forEach(targetExeVertex -> {
// pre-process some mappings
executionVertexMap.put(targetExeVertex.getExecutionVertexId(), targetExeVertex);
executionVertexMap.put(sourceExeVertex.getExecutionVertexId(), sourceExeVertex);
// build execution edge
ExecutionEdge executionEdge =
new ExecutionEdge(sourceExeVertex, targetExeVertex, executionJobEdge);
sourceExeVertex.getOutputEdges().add(executionEdge);
targetExeVertex.getInputEdges().add(executionEdge);
source.getExecutionVertices().forEach(vertex -> {
target.getExecutionVertices().forEach(outputVertex -> {
ExecutionEdge executionEdge = new ExecutionEdge(vertex, outputVertex, executionJobEdge);
vertex.getOutputEdges().add(executionEdge);
outputVertex.getInputEdges().add(executionEdge);
});
});
});
// set execution job vertex into execution graph
executionGraph.setExecutionJobVertexMap(exeJobVertexMap);
executionGraph.setExecutionVertexMap(executionVertexMap);
return executionGraph;
}
/**
 * Adds an actor to the set registered for a channel, creating the set on first use.
 *
 * @param channelGroupedActors channel id to actor-set mapping, updated in place
 * @param channelId channel the actor belongs to
 * @param actor actor to register
 */
private void addActorToChannelGroupedActors(
    Map<String, Set<BaseActorHandle>> channelGroupedActors,
    String channelId,
    BaseActorHandle actor) {
  channelGroupedActors
      .computeIfAbsent(channelId, unusedKey -> new HashSet<>())
      .add(actor);
}
@Override
public JobGraph getJobGraph() {
return runtimeContext.getJobGraph();
@@ -11,7 +11,7 @@ import io.ray.streaming.runtime.config.types.ResourceAssignStrategyType;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph;
import io.ray.streaming.runtime.core.resource.Container;
import io.ray.streaming.runtime.core.resource.Resources;
import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext;
import io.ray.streaming.runtime.master.JobRuntimeContext;
import io.ray.streaming.runtime.master.resourcemanager.strategy.ResourceAssignStrategy;
import io.ray.streaming.runtime.master.resourcemanager.strategy.ResourceAssignStrategyFactory;
import io.ray.streaming.runtime.util.RayUtils;
@@ -30,33 +30,39 @@ public class ResourceManagerImpl implements ResourceManager {
//Container used tag
private static final String CONTAINER_ENGAGED_KEY = "CONTAINER_ENGAGED_KEY";
/**
* Resource description information.
*/
private final Resources resources;
/**
* Timing resource updating thread
*/
private final ScheduledExecutorService resourceUpdater = new ScheduledThreadPoolExecutor(1,
new ThreadFactoryBuilder().setNameFormat("resource-update-thread").build());
/**
* Job runtime context.
*/
private JobMasterRuntimeContext runtimeContext;
private JobRuntimeContext runtimeContext;
/**
* Resource related configuration.
*/
private ResourceConfig resourceConfig;
/**
* Slot assign strategy.
*/
private ResourceAssignStrategy resourceAssignStrategy;
/**
* Resource description information.
*/
private final Resources resources;
/**
* Customized actor number for each container
*/
private int actorNumPerContainer;
public ResourceManagerImpl(JobMasterRuntimeContext runtimeContext) {
/**
* Timing resource updating thread
*/
private final ScheduledExecutorService resourceUpdater = new ScheduledThreadPoolExecutor(1,
new ThreadFactoryBuilder().setNameFormat("resource-update-thread").build());
public ResourceManagerImpl(JobRuntimeContext runtimeContext) {
this.runtimeContext = runtimeContext;
StreamingMasterConfig masterConfig = runtimeContext.getConf().masterConfig;
@@ -23,18 +23,20 @@ import org.slf4j.LoggerFactory;
public class JobSchedulerImpl implements JobScheduler {
private static final Logger LOG = LoggerFactory.getLogger(JobSchedulerImpl.class);
private StreamingConfig jobConf;
private final JobMaster jobMaster;
private final ResourceManager resourceManager;
private final GraphManager graphManager;
private final WorkerLifecycleController workerLifecycleController;
private StreamingConfig jobConfig;
public JobSchedulerImpl(JobMaster jobMaster) {
this.jobMaster = jobMaster;
this.graphManager = jobMaster.getGraphManager();
this.resourceManager = jobMaster.getResourceManager();
this.workerLifecycleController = new WorkerLifecycleController();
this.jobConfig = jobMaster.getRuntimeContext().getConf();
this.jobConf = jobMaster.getRuntimeContext().getConf();
LOG.info("Scheduler initiated.");
}
@@ -44,13 +46,8 @@ public class JobSchedulerImpl implements JobScheduler {
LOG.info("Begin scheduling. Job: {}.", executionGraph.getJobName());
// Allocate resource then create workers
// Actor creation is in this step
prepareResourceAndCreateWorker(executionGraph);
// now actor info is available in execution graph
// preprocess some handy mappings in execution graph
executionGraph.generateActorMappings();
// init worker context and start to run
initAndStart(executionGraph);
@@ -90,7 +87,7 @@ public class JobSchedulerImpl implements JobScheduler {
initMaster();
// start workers
startWorkers(executionGraph, jobMaster.getRuntimeContext().lastCheckpointId);
startWorkers(executionGraph);
}
/**
@@ -125,7 +122,7 @@ public class JobSchedulerImpl implements JobScheduler {
boolean result;
try {
result = workerLifecycleController.initWorkers(vertexToContextMap,
jobConfig.masterConfig.schedulerConfig.workerInitiationWaitTimeoutMs());
jobConf.masterConfig.schedulerConfig.workerInitiationWaitTimeoutMs());
} catch (Exception e) {
LOG.error("Failed to initiate workers.", e);
return false;
@@ -136,12 +133,11 @@ public class JobSchedulerImpl implements JobScheduler {
/**
* Start JobWorkers according to the physical plan.
*/
public boolean startWorkers(ExecutionGraph executionGraph, long checkpointId) {
public boolean startWorkers(ExecutionGraph executionGraph) {
boolean result;
try {
result = workerLifecycleController.startWorkers(
executionGraph, checkpointId,
jobConfig.masterConfig.schedulerConfig.workerStartingWaitTimeoutMs());
executionGraph, jobConf.masterConfig.schedulerConfig.workerStartingWaitTimeoutMs());
} catch (Exception e) {
LOG.error("Failed to start workers.", e);
return false;
@@ -198,7 +194,7 @@ public class JobSchedulerImpl implements JobScheduler {
}
private void initMaster() {
jobMaster.init(false);
jobMaster.init();
}
}
@@ -9,8 +9,6 @@ import io.ray.api.id.ActorId;
import io.ray.streaming.api.Language;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
import io.ray.streaming.runtime.generated.RemoteCall;
import io.ray.streaming.runtime.python.GraphPbBuilder;
import io.ray.streaming.runtime.rpc.RemoteCallWorker;
import io.ray.streaming.runtime.worker.JobWorker;
import io.ray.streaming.runtime.worker.context.JobWorkerContext;
@@ -42,23 +40,20 @@ public class WorkerLifecycleController {
* @return creation result
*/
private boolean createWorker(ExecutionVertex executionVertex) {
LOG.info("Start to create worker actor for vertex: {} with resource: {}, workeConfig: {}.",
executionVertex.getExecutionVertexName(), executionVertex.getResource(),
executionVertex.getWorkerConfig());
LOG.info("Start to create worker actor for vertex: {} with resource: {}.",
executionVertex.getExecutionVertexName(), executionVertex.getResource());
Language language = executionVertex.getLanguage();
BaseActorHandle actor;
if (Language.JAVA == language) {
actor = Ray.actor(JobWorker::new, executionVertex)
actor = Ray.actor(JobWorker::new)
.setResources(executionVertex.getResource())
.setMaxRestarts(-1)
.remote();
} else {
RemoteCall.ExecutionVertexContext.ExecutionVertex vertexPb
= new GraphPbBuilder().buildVertex(executionVertex);
actor = Ray.actor(
PyActorClass.of("ray.streaming.runtime.worker", "JobWorker"), vertexPb.toByteArray())
PyActorClass.of("ray.streaming.runtime.worker", "JobWorker"))
.setResources(executionVertex.getResource())
.setMaxRestarts(-1)
.remote();
@@ -116,20 +111,20 @@ public class WorkerLifecycleController {
* @param timeout timeout for waiting, unit: ms
* @return starting result
*/
public boolean startWorkers(ExecutionGraph executionGraph, long lastCheckpointId, int timeout) {
public boolean startWorkers(ExecutionGraph executionGraph, int timeout) {
LOG.info("Begin starting workers.");
long startTime = System.currentTimeMillis();
List<ObjectRef<Object>> objectRefs = new ArrayList<>();
List<ObjectRef<Boolean>> objectRefs = new ArrayList<>();
// start source actors 1st
executionGraph.getSourceActors()
.forEach(actor -> objectRefs.add(RemoteCallWorker.rollback(actor, lastCheckpointId)));
.forEach(actor -> objectRefs.add(RemoteCallWorker.startWorker(actor)));
// then start non-source actors
executionGraph.getNonSourceActors()
.forEach(actor -> objectRefs.add(RemoteCallWorker.rollback(actor, lastCheckpointId)));
.forEach(actor -> objectRefs.add(RemoteCallWorker.startWorker(actor)));
WaitResult<Object> result = Ray.wait(objectRefs, objectRefs.size(), timeout);
WaitResult<Boolean> result = Ray.wait(objectRefs, objectRefs.size(), timeout);
if (result.getReady().size() != objectRefs.size()) {
LOG.error("Starting workers timeout[{} ms].", timeout);
return false;
@@ -1,122 +0,0 @@
package io.ray.streaming.runtime.message;
import com.google.common.base.MoreObjects;
import java.io.Serializable;
public class CallResult<T> implements Serializable {
protected T resultObj;
private boolean success;
private int resultCode;
private String resultMsg;
public CallResult() {
}
public CallResult(boolean success, int resultCode, String resultMsg, T resultObj) {
this.success = success;
this.resultCode = resultCode;
this.resultMsg = resultMsg;
this.resultObj = resultObj;
}
public static <T> CallResult<T> success() {
return new CallResult<>(true, CallResultEnum.SUCCESS.code, CallResultEnum.SUCCESS.msg, null);
}
public static <T> CallResult<T> success(T payload) {
return new CallResult<>(true, CallResultEnum.SUCCESS.code, CallResultEnum.SUCCESS.msg, payload);
}
public static <T> CallResult<T> skipped(String msg) {
return new CallResult<>(true, CallResultEnum.SKIPPED.code, msg, null);
}
public static <T> CallResult<T> fail() {
return new CallResult<>(false, CallResultEnum.FAILED.code, CallResultEnum.FAILED.msg, null);
}
public static <T> CallResult<T> fail(T payload) {
return new CallResult<>(false, CallResultEnum.FAILED.code, CallResultEnum.FAILED.msg, payload);
}
public static <T> CallResult<T> fail(String msg) {
return new CallResult<>(false, CallResultEnum.FAILED.code, msg, null);
}
public static <T> CallResult<T> fail(CallResultEnum resultEnum, T payload) {
return new CallResult<>(false, resultEnum.code, resultEnum.msg, payload);
}
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("resultObj", resultObj)
.add("success", success)
.add("resultCode", resultCode)
.add("resultMsg", resultMsg)
.toString();
}
public boolean isSuccess() {
return this.success;
}
public void setSuccess(boolean success) {
this.success = success;
}
public int getResultCode() {
return this.resultCode;
}
public void setResultCode(int resultCode) {
this.resultCode = resultCode;
}
public CallResultEnum getResultEnum() {
return CallResultEnum.getEnum(this.resultCode);
}
public String getResultMsg() {
return this.resultMsg;
}
public void setResultMsg(String resultMsg) {
this.resultMsg = resultMsg;
}
public T getResultObj() {
return this.resultObj;
}
public void setResultObj(T resultObj) {
this.resultObj = resultObj;
}
public enum CallResultEnum implements Serializable {
/**
* call result enum
*/
SUCCESS(0, "SUCCESS"),
FAILED(1, "FAILED"),
SKIPPED(2, "SKIPPED");
public final int code;
public final String msg;
CallResultEnum(int code, String msg) {
this.code = code;
this.msg = msg;
}
public static CallResultEnum getEnum(int code) {
for (CallResultEnum value : CallResultEnum.values()) {
if (code == value.code) {
return value;
}
}
return FAILED;
}
}
}
@@ -65,7 +65,7 @@ public class GraphPbBuilder {
return builder.build();
}
public RemoteCall.ExecutionVertexContext.ExecutionVertex buildVertex(
private RemoteCall.ExecutionVertexContext.ExecutionVertex buildVertex(
ExecutionVertex executionVertex) {
// build vertex infos
RemoteCall.ExecutionVertexContext.ExecutionVertex.Builder executionVertexBuilder =
@@ -79,11 +79,9 @@ public class GraphPbBuilder {
ByteString.copyFrom(
serializeOperator(executionVertex.getStreamOperator())));
executionVertexBuilder.setChained(isPythonChainedOperator(executionVertex.getStreamOperator()));
if (executionVertex.getWorkerActor() != null) {
executionVertexBuilder.setWorkerActor(
ByteString.copyFrom(
((NativeActorHandle) (executionVertex.getWorkerActor())).toBytes()));
}
executionVertexBuilder.setWorkerActor(
ByteString.copyFrom(
((NativeActorHandle) (executionVertex.getWorkerActor())).toBytes()));
executionVertexBuilder.setContainerId(executionVertex.getContainerId().toString());
executionVertexBuilder.setBuildTime(executionVertex.getBuildTime());
executionVertexBuilder.setLanguage(
@@ -1,54 +0,0 @@
package io.ray.streaming.runtime.rpc;
import com.google.protobuf.InvalidProtocolBufferException;
import io.ray.streaming.runtime.generated.RemoteCall;
import io.ray.streaming.runtime.message.CallResult;
import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo;
import java.util.HashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Helpers that turn protobuf-encoded call results into Java values.
 */
public class PbResultParser {

  private static final Logger LOG = LoggerFactory.getLogger(PbResultParser.class);

  /**
   * Decodes a serialized {@code RemoteCall.BoolResult}.
   *
   * @param result serialized message, may be null
   * @return the decoded boolean; {@code false} when {@code result} is null or cannot be parsed
   */
  public static Boolean parseBoolResult(byte[] result) {
    if (null == result) {
      LOG.warn("Result is null.");
      return false;
    }

    RemoteCall.BoolResult boolResult;
    try {
      boolResult = RemoteCall.BoolResult.parseFrom(result);
    } catch (InvalidProtocolBufferException e) {
      LOG.error("Parse boolean result has exception.", e);
      return false;
    }
    return boolResult.getBoolRes();
  }

  /**
   * Decodes a serialized {@code RemoteCall.CallResult} into a {@link CallResult} whose payload is
   * the channel creation status map.
   *
   * @param bytes serialized message, may be null
   * @return the decoded result; {@link CallResult#fail()} when {@code bytes} is null or cannot be
   *     parsed
   */
  public static CallResult<ChannelRecoverInfo> parseRollbackResult(byte[] bytes) {
    // Mirror parseBoolResult: a missing payload is reported as a failed result instead of
    // escaping as an unchecked exception from the protobuf parser.
    if (null == bytes) {
      LOG.warn("Rollback result is null.");
      return CallResult.fail();
    }

    RemoteCall.CallResult callResultPb;
    try {
      callResultPb = RemoteCall.CallResult.parseFrom(bytes);
    } catch (InvalidProtocolBufferException e) {
      LOG.error("Rollback parse result has exception.", e);
      return CallResult.fail();
    }

    CallResult<ChannelRecoverInfo> callResult = new CallResult<>();
    callResult.setSuccess(callResultPb.getSuccess());
    callResult.setResultCode(callResultPb.getResultCode());
    callResult.setResultMsg(callResultPb.getResultMsg());

    // Translate each protobuf creation status into its Java enum counterpart by number.
    RemoteCall.QueueRecoverInfo recoverInfo = callResultPb.getResultObj();
    Map<String, ChannelRecoverInfo.ChannelCreationStatus> creationStatusMap = new HashMap<>();
    recoverInfo.getCreationStatusMap().forEach((k, v) -> {
      creationStatusMap.put(k, ChannelRecoverInfo.ChannelCreationStatus.fromInt(v.getNumber()));
    });
    callResult.setResultObj(new ChannelRecoverInfo(creationStatusMap));
    return callResult;
  }
}
@@ -1,46 +0,0 @@
package io.ray.streaming.runtime.rpc;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.streaming.runtime.generated.RemoteCall;
import io.ray.streaming.runtime.master.JobMaster;
import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport;
import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest;
public class RemoteCallMaster {
public static ObjectRef<byte[]> reportJobWorkerCommitAsync(
ActorHandle<JobMaster> actor,
WorkerCommitReport commitReport) {
RemoteCall.WorkerCommitReport commit = RemoteCall.WorkerCommitReport.newBuilder()
.setCommitCheckpointId(commitReport.commitCheckpointId)
.build();
Any detail = Any.pack(commit);
RemoteCall.BaseWorkerCmd cmd = RemoteCall.BaseWorkerCmd.newBuilder()
.setActorId(ByteString.copyFrom(commitReport.fromActorId.getBytes()))
.setTimestamp(System.currentTimeMillis())
.setDetail(detail).build();
return actor.task(JobMaster::reportJobWorkerCommit, cmd.toByteArray()).remote();
}
public static Boolean requestJobWorkerRollback(
ActorHandle<JobMaster> actor,
WorkerRollbackRequest rollbackRequest) {
RemoteCall.WorkerRollbackRequest request = RemoteCall.WorkerRollbackRequest.newBuilder()
.setExceptionMsg(rollbackRequest.getRollbackExceptionMsg())
.setWorkerHostname(rollbackRequest.getHostname())
.setWorkerPid(rollbackRequest.getPid()).build();
Any detail = Any.pack(request);
RemoteCall.BaseWorkerCmd cmd = RemoteCall.BaseWorkerCmd.newBuilder()
.setActorId(ByteString.copyFrom(rollbackRequest.fromActorId.getBytes()))
.setTimestamp(System.currentTimeMillis())
.setDetail(detail).build();
ObjectRef<byte[]> ret = actor.task(
JobMaster::requestJobWorkerRollback, cmd.toByteArray()).remote();
byte[] res = ret.get();
return PbResultParser.parseBoolResult(res);
}
}
@@ -4,15 +4,10 @@ import io.ray.api.ActorHandle;
import io.ray.api.BaseActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
import io.ray.api.Ray;
import io.ray.api.function.PyActorMethod;
import io.ray.api.function.RayFunc3;
import io.ray.streaming.runtime.generated.RemoteCall;
import io.ray.streaming.runtime.master.JobMaster;
import io.ray.streaming.runtime.worker.JobWorker;
import io.ray.streaming.runtime.worker.context.JobWorkerContext;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -51,26 +46,19 @@ public class RemoteCallWorker {
* Call JobWorker actor to start.
*
* @param actor target JobWorker actor
* @param checkpointId checkpoint ID to be rollback
* @return start result
*/
public static ObjectRef rollback(BaseActorHandle actor, final Long checkpointId) {
public static ObjectRef<Boolean> startWorker(BaseActorHandle actor) {
LOG.info("Call worker to start, actor: {}.", actor.getId());
ObjectRef result;
ObjectRef<Boolean> result = null;
// python
if (actor instanceof PyActorHandle) {
RemoteCall.CheckpointId checkpointIdPb = RemoteCall.CheckpointId.newBuilder()
.setCheckpointId(checkpointId)
.build();
result = ((PyActorHandle) actor)
.task(PyActorMethod.of("rollback"),
checkpointIdPb.toByteArray()
).remote();
.task(PyActorMethod.of("start", Boolean.class)).remote();
} else {
// java
result = ((ActorHandle<JobWorker>) actor)
.task(JobWorker::rollback, checkpointId, System.currentTimeMillis()).remote();
result = ((ActorHandle<JobWorker>) actor).task(JobWorker::start).remote();
}
LOG.info("Finished calling worker to start.");
@@ -94,92 +82,4 @@ public class RemoteCallWorker {
return result;
}
/**
 * Asks a worker actor to take a checkpoint for the given barrier.
 *
 * @param actor target worker actor, Python or Java
 * @param barrierId id of the barrier to checkpoint
 * @return reference to the remote call's result
 */
public static ObjectRef triggerCheckpoint(BaseActorHandle actor, Long barrierId) {
  if (!(actor instanceof PyActorHandle)) {
    // java
    ActorHandle<JobWorker> javaWorker = (ActorHandle<JobWorker>) actor;
    return javaWorker.task(JobWorker::triggerCheckpoint, barrierId).remote();
  }
  // python: the barrier travels as a serialized protobuf message to the "commit" method
  PyActorHandle pyWorker = (PyActorHandle) actor;
  RemoteCall.Barrier.Builder barrierBuilder = RemoteCall.Barrier.newBuilder();
  barrierBuilder.setId(barrierId);
  byte[] barrierBytes = barrierBuilder.build().toByteArray();
  return pyWorker.task(PyActorMethod.of("commit"), barrierBytes).remote();
}
/**
 * Tells every given worker to clear expired checkpoints and waits for all replies.
 *
 * @param actors worker actors to call
 * @param stateCheckpointId state checkpoint id passed to each worker
 * @param queueCheckpointId queue checkpoint id passed to each worker
 */
public static void clearExpiredCheckpointParallel(
    List<BaseActorHandle> actors, Long stateCheckpointId,
    Long queueCheckpointId) {
  if (LOG.isInfoEnabled()) {
    LOG.info("Call worker clearExpiredCheckpoint, state checkpoint id is {},"
        + " queue checkpoint id is {}.", stateCheckpointId, queueCheckpointId);
  }

  List<Object> replies = checkpointCompleteCommonCallTwoWay(
      actors, stateCheckpointId, queueCheckpointId,
      "clear_expired_cp", JobWorker::clearExpiredCheckpoint);

  if (LOG.isInfoEnabled()) {
    for (Object reply : replies) {
      LOG.info("Finish call worker clearExpiredCheckpointParallel, ret is {}.", reply);
    }
  }
}
/**
 * Notifies every given worker that a checkpoint timed out; replies are not awaited.
 *
 * @param actors worker actors to notify
 * @param checkpointId id of the checkpoint that timed out
 */
public static void notifyCheckpointTimeoutParallel(
    List<BaseActorHandle> actors,
    Long checkpointId) {
  LOG.info("Call worker notifyCheckpointTimeoutParallel, checkpoint id is {}", checkpointId);

  for (BaseActorHandle worker : actors) {
    if (!(worker instanceof PyActorHandle)) {
      // java
      ((ActorHandle<JobWorker>) worker).task(JobWorker::notifyCheckpointTimeout, checkpointId)
          .remote();
      continue;
    }
    // python: the id is sent as a serialized protobuf message
    RemoteCall.CheckpointId.Builder idBuilder = RemoteCall.CheckpointId.newBuilder();
    idBuilder.setCheckpointId(checkpointId);
    byte[] idBytes = idBuilder.build().toByteArray();
    ((PyActorHandle) worker).task(PyActorMethod.of("notify_checkpoint_timeout"), idBytes)
        .remote();
  }

  LOG.info("Finish call worker notifyCheckpointTimeoutParallel.");
}
/**
 * Issues the checkpoint-complete call on every actor and blocks until all results are fetched.
 *
 * @param pyFuncName method name invoked on Python actors
 * @param rayFunc method invoked on Java actors
 * @return the fetched results of all calls
 */
private static List<Object> checkpointCompleteCommonCallTwoWay(
    List<BaseActorHandle> actors, Long stateCheckpointId, Long queueCheckpointId,
    String pyFuncName, RayFunc3<JobWorker, Long, Long, Boolean> rayFunc) {
  return Ray.get(
      checkpointCompleteCommonCall(
          actors, stateCheckpointId, queueCheckpointId, pyFuncName, rayFunc));
}
/**
 * Issues the checkpoint-complete call on every actor without waiting for the results.
 *
 * @param actors worker actors to call
 * @param stateCheckpointId state checkpoint id passed to each worker
 * @param queueCheckpointId queue checkpoint id passed to each worker
 * @param pyFuncName method name invoked on Python actors
 * @param rayFunc method invoked on Java actors
 * @return one pending reference per actor, in the order of {@code actors}
 */
private static List<ObjectRef<Object>> checkpointCompleteCommonCall(
    List<BaseActorHandle> actors,
    Long stateCheckpointId, Long queueCheckpointId,
    String pyFuncName,
    RayFunc3<JobWorker, Long, Long, Boolean> rayFunc) {
  List<ObjectRef<Object>> pending = new ArrayList<>();
  for (BaseActorHandle worker : actors) {
    if (!(worker instanceof PyActorHandle)) {
      // java
      pending.add(((ActorHandle) worker).task(rayFunc, stateCheckpointId, queueCheckpointId)
          .remote());
      continue;
    }
    // python: both ids are sent as serialized protobuf messages
    RemoteCall.CheckpointId statePb = RemoteCall.CheckpointId.newBuilder()
        .setCheckpointId(stateCheckpointId)
        .build();
    RemoteCall.CheckpointId queuePb = RemoteCall.CheckpointId.newBuilder()
        .setCheckpointId(queueCheckpointId)
        .build();
    pending.add(((PyActorHandle) worker).task(PyActorMethod.of(pyFuncName),
        statePb.toByteArray(), queuePb.toByteArray()).remote());
  }
  return pending;
}
}
@@ -1,131 +0,0 @@
package io.ray.streaming.runtime.rpc.async;
import io.ray.api.ActorHandle;
import io.ray.api.BaseActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
import io.ray.api.function.PyActorMethod;
import io.ray.streaming.runtime.generated.RemoteCall;
import io.ray.streaming.runtime.message.CallResult;
import io.ray.streaming.runtime.rpc.PbResultParser;
import io.ray.streaming.runtime.rpc.async.RemoteCallPool.Callback;
import io.ray.streaming.runtime.rpc.async.RemoteCallPool.ExceptionHandler;
import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo;
import io.ray.streaming.runtime.worker.JobWorker;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Issues worker calls whose results are delivered to callbacks through a {@link RemoteCallPool}
 * instead of being awaited by the caller.
 */
@SuppressWarnings("unchecked")
public class AsyncRemoteCaller {

  private static final Logger LOG = LoggerFactory.getLogger(AsyncRemoteCaller.class);
  private RemoteCallPool remoteCallPool = new RemoteCallPool();

  /**
   * Call JobWorker::checkIfNeedRollback async
   *
   * @param actor JobWorker actor
   * @param callback callback function on success
   * @param onException callback function on exception
   */
  public void checkIfNeedRollbackAsync(
      BaseActorHandle actor, Callback<Boolean> callback,
      ExceptionHandler<Throwable> onException) {
    if (!(actor instanceof PyActorHandle)) {
      // java: the result already has the callback's type
      remoteCallPool.bindCallback(
          ((ActorHandle<JobWorker>) actor).task(JobWorker::checkIfNeedRollback,
              System.currentTimeMillis()).remote(), callback, onException);
      return;
    }
    // python: the reply is a serialized bool result that must be decoded first
    remoteCallPool.bindCallback(
        ((PyActorHandle) actor).task(PyActorMethod.of("check_if_need_rollback")).remote(),
        (reply) -> {
          byte[] replyBytes = (byte[]) reply;
          callback.handle(PbResultParser.parseBoolResult(replyBytes));
        }, onException);
  }

  /**
   * Call JobWorker::rollback async
   *
   * @param actor JobWorker actor
   * @param checkpointId checkpoint id to roll back to
   * @param callback callback function on success
   * @param onException callback function on exception
   */
  public void rollback(
      BaseActorHandle actor,
      final Long checkpointId,
      Callback<CallResult<ChannelRecoverInfo>> callback,
      ExceptionHandler<Throwable> onException) {
    ObjectRef call = submitRollback(actor, checkpointId);
    if (actor instanceof PyActorHandle) {
      // python: decode the serialized reply
      remoteCallPool.bindCallback(call, reply ->
          callback.handle(PbResultParser.parseRollbackResult((byte[]) reply)), onException);
    } else {
      // java: the reply is already a CallResult
      remoteCallPool.bindCallback(call, reply -> {
        CallResult<ChannelRecoverInfo> typedReply = (CallResult<ChannelRecoverInfo>) reply;
        callback.handle(typedReply);
      }, onException);
    }
  }

  /**
   * Call JobWorker::rollback async in batch
   *
   * @param actors JobWorker actor list
   * @param checkpointId checkpoint id to roll back to
   * @param abnormalQueues not read by this method
   * @param callback callback function on success
   * @param onException callback function on exception
   */
  public void batchRollback(
      List<BaseActorHandle> actors, final Long checkpointId,
      Collection<String> abnormalQueues,
      Callback<List<CallResult<ChannelRecoverInfo>>> callback,
      ExceptionHandler<Throwable> onException) {
    List<ObjectRef<Object>> pendingCalls = new ArrayList<>();
    // Remembers which positions belong to Python actors, whose replies need decoding.
    Map<Integer, Boolean> pythonPositions = new HashMap<>();
    int position = 0;
    for (BaseActorHandle actor : actors) {
      if (actor instanceof PyActorHandle) {
        pythonPositions.put(position, true);
      }
      ObjectRef call = submitRollback(actor, checkpointId);
      pendingCalls.add(call);
      position++;
    }

    remoteCallPool.bindCallback(pendingCalls, replies -> {
      List<CallResult<ChannelRecoverInfo>> decoded = new ArrayList<>();
      for (int idx = 0; idx < replies.size(); ++idx) {
        Object reply = replies.get(idx);
        if (pythonPositions.getOrDefault(idx, false)) {
          decoded.add(PbResultParser.parseRollbackResult((byte[]) reply));
        } else {
          decoded.add((CallResult<ChannelRecoverInfo>) reply);
        }
      }
      callback.handle(decoded);
    }, onException);
  }

  /** Submits the rollback task to a Python or Java worker and returns the pending reference. */
  private static ObjectRef submitRollback(BaseActorHandle actor, Long checkpointId) {
    if (actor instanceof PyActorHandle) {
      // python
      RemoteCall.CheckpointId checkpointIdPb = RemoteCall.CheckpointId.newBuilder()
          .setCheckpointId(checkpointId)
          .build();
      return ((PyActorHandle) actor).task(PyActorMethod.of("rollback"),
          checkpointIdPb.toByteArray()).remote();
    }
    // java
    return ((ActorHandle<JobWorker>) actor).task(
        JobWorker::rollback, checkpointId, System.currentTimeMillis()).remote();
  }
}
@@ -1,189 +0,0 @@
package io.ray.streaming.runtime.rpc.async;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.WaitResult;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Polls pending Ray object references on a dedicated loop thread and, once every reference of a
 * bundle is ready, hands the fetched results to the registered callback on a worker pool.
 *
 * <p>{@code bindCallback} may be called from any thread; the pending list is guarded by its own
 * monitor and the handler maps are concurrent maps.
 */
public class RemoteCallPool implements Runnable {

  private static final Logger LOG = LoggerFactory.getLogger(RemoteCallPool.class);
  /** Timeout passed to each {@code Ray.wait} call and sleep time when nothing is pending. */
  private static final int WAIT_TIME_MS = 5;
  /** Minimum bundle age, and minimum gap between two warnings, for the slow-bundle warning. */
  private static final long WARNING_PERIOD = 10000;
  /** Core size of the callback pool. */
  private static final int CORE_CALLBACK_THREADS = 2;

  private final List<RemoteCallBundle> pendingObjectBundles = new LinkedList<>();
  private final Map<RemoteCallBundle, Callback<Object>> singletonHandlerMap =
      new ConcurrentHashMap<>();
  private final Map<RemoteCallBundle, Callback<List<Object>>> bundleHandlerMap =
      new ConcurrentHashMap<>();
  private final Map<RemoteCallBundle, ExceptionHandler<Throwable>> bundleExceptionHandlerMap =
      new ConcurrentHashMap<>();
  // The maximum must never be below the core size, otherwise ThreadPoolExecutor's constructor
  // throws IllegalArgumentException; that happened on hosts reporting a single processor.
  private final ThreadPoolExecutor callBackPool = new ThreadPoolExecutor(
      CORE_CALLBACK_THREADS,
      Math.max(CORE_CALLBACK_THREADS, Runtime.getRuntime().availableProcessors()),
      1, TimeUnit.MINUTES, new LinkedBlockingQueue<>(),
      new CallbackThreadFactory());
  private volatile boolean stop = false;

  /** Creates the pool and immediately starts its polling thread. */
  public RemoteCallPool() {
    Thread t = new Thread(Ray.wrapRunnable(this), "remote-pool-loop");
    t.setUncaughtExceptionHandler((thread, throwable) ->
        LOG.error("Error in remote call pool thread.", throwable)
    );
    t.start();
  }

  /**
   * Registers a callback for a single pending reference.
   *
   * @param obj reference to wait for
   * @param callback invoked with the fetched value once the reference is ready
   * @param onException invoked when fetching the value or the callback throws; may be null
   */
  @SuppressWarnings("unchecked")
  public <T> void bindCallback(
      ObjectRef<T> obj, Callback<T> callback,
      ExceptionHandler<Throwable> onException) {
    List objectRefList = Collections.singletonList(obj);
    RemoteCallBundle bundle = new RemoteCallBundle(objectRefList, true);
    // Handlers are stored before the bundle becomes visible to the loop thread.
    singletonHandlerMap.put(bundle, (Callback<Object>) callback);
    bundleExceptionHandlerMap.put(bundle, onException);
    synchronized (pendingObjectBundles) {
      pendingObjectBundles.add(bundle);
    }
  }

  /**
   * Registers a callback for a group of pending references.
   *
   * @param objectBundle references to wait for; the callback fires once all are ready
   * @param callback invoked with the fetched values, in the order of {@code objectBundle}
   * @param onException invoked when fetching a value or the callback throws; may be null
   */
  public void bindCallback(
      List<ObjectRef<Object>> objectBundle, Callback<List<Object>> callback,
      ExceptionHandler<Throwable> onException) {
    RemoteCallBundle bundle = new RemoteCallBundle(objectBundle, false);
    bundleHandlerMap.put(bundle, callback);
    bundleExceptionHandlerMap.put(bundle, onException);
    synchronized (pendingObjectBundles) {
      pendingObjectBundles.add(bundle);
    }
  }

  /** Requests the polling loop to exit; callbacks already dispatched still run. */
  public void stop() {
    stop = true;
  }

  /** Polling loop: checks every pending bundle and dispatches those that are fully ready. */
  @Override
  public void run() {
    while (!stop) {
      try {
        boolean nothingPending;
        // LinkedList is not thread-safe, so even the emptiness check needs the lock.
        synchronized (pendingObjectBundles) {
          nothingPending = pendingObjectBundles.isEmpty();
        }
        if (nothingPending) {
          Thread.sleep(WAIT_TIME_MS);
          continue;
        }
        synchronized (pendingObjectBundles) {
          Iterator<RemoteCallBundle> itr = pendingObjectBundles.iterator();
          while (itr.hasNext()) {
            RemoteCallBundle bundle = itr.next();
            WaitResult<Object> waitResult =
                Ray.wait(bundle.objects, bundle.objects.size(), WAIT_TIME_MS);
            if (waitResult.getReady().size() != bundle.objects.size()) {
              warnIfSlow(bundle);
              continue;
            }
            // Remove before dispatching so a bundle is never handed out twice.
            itr.remove();
            dispatch(bundle);
          }
        }
      } catch (InterruptedException e) {
        // Restore the flag and leave: continuing would make every later sleep fail at once.
        Thread.currentThread().interrupt();
        LOG.warn("Wait loop interrupted, exiting.");
        break;
      } catch (Exception e) {
        LOG.error("Exception in wait loop.", e);
      }
    }
    // No further callbacks can be dispatched; let queued ones finish and release the threads.
    callBackPool.shutdown();
    LOG.info("Wait loop finished.");
  }

  /** Logs a warning for a bundle that has been pending longer than the warning period. */
  private void warnIfSlow(RemoteCallBundle bundle) {
    long now = System.currentTimeMillis();
    long waitingTime = now - bundle.createTime;
    if (waitingTime > WARNING_PERIOD && now - bundle.lastWarnTs > WARNING_PERIOD) {
      bundle.lastWarnTs = now;
      LOG.warn("Bundle has being waiting for {} ms, bundle = {}.", waitingTime, bundle);
    }
  }

  /**
   * Hands a fully ready bundle to the callback pool. The handlers are taken out of the maps here
   * so no entry outlives its bundle, whether the callback succeeds or fails. Values are fetched
   * on the pool thread inside the try block, so a failing fetch reaches the exception handler
   * instead of breaking the polling loop.
   */
  private void dispatch(RemoteCallBundle bundle) {
    ExceptionHandler<Throwable> exceptionHandler = bundleExceptionHandlerMap.remove(bundle);
    // Iterate the bundle's own list so result positions match submission order.
    List<ObjectRef<Object>> refs = bundle.objects;
    if (bundle.isSingletonBundle) {
      Callback<Object> handler = singletonHandlerMap.remove(bundle);
      ObjectRef<Object> ref = refs.get(0);
      callBackPool.execute(Ray.wrapRunnable(() -> {
        try {
          handler.handle(ref.get());
        } catch (Throwable th) {
          LOG.error("Error when get object, objectId = {}.", ref.toString(), th);
          if (exceptionHandler != null) {
            exceptionHandler.handle(th);
          }
        }
      }));
    } else {
      Callback<List<Object>> handler = bundleHandlerMap.remove(bundle);
      List<String> resultIds =
          refs.stream().map(ObjectRef::toString).collect(Collectors.toList());
      callBackPool.execute(Ray.wrapRunnable(() -> {
        try {
          List<Object> results =
              refs.stream().map(ObjectRef::get).collect(Collectors.toList());
          handler.handle(results);
        } catch (Throwable th) {
          LOG.error("Error when get object, objectIds = {}.", resultIds, th);
          if (exceptionHandler != null) {
            exceptionHandler.handle(th);
          }
        }
      }));
    }
  }

  /** Receives the error raised while fetching a result or running a callback. */
  @FunctionalInterface
  public interface ExceptionHandler<T> {

    void handle(T object);
  }

  /** Receives the fetched result of a remote call; may throw. */
  @FunctionalInterface
  public interface Callback<T> {

    void handle(T object) throws Throwable;
  }

  /** A group of references that must all be ready before its callback runs. */
  private static class RemoteCallBundle {

    final List<ObjectRef<Object>> objects;
    final boolean isSingletonBundle;
    long lastWarnTs = System.currentTimeMillis();
    final long createTime = System.currentTimeMillis();

    RemoteCallBundle(List<ObjectRef<Object>> objects, boolean isSingletonBundle) {
      this.objects = objects;
      this.isSingletonBundle = isSingletonBundle;
    }

    @Override
    public String toString() {
      StringBuilder sb = new StringBuilder();
      sb.append("[");
      objects.forEach(rayObj -> sb.append(rayObj.toString()).append(","));
      sb.append("]");
      return sb.toString();
    }
  }

  /** Names callback threads and logs exceptions that escape them. */
  static class CallbackThreadFactory implements ThreadFactory {

    private final AtomicInteger cnt = new AtomicInteger(0);

    @Override
    public Thread newThread(Runnable r) {
      Thread t = new Thread(r);
      t.setUncaughtExceptionHandler((thread, throwable) -> LOG.error("Callback err.", throwable));
      t.setName("callback-thread-" + cnt.getAndIncrement());
      return t;
    }
  }
}
@@ -1,4 +1,4 @@
package io.ray.streaming.runtime.transfer.channel;
package io.ray.streaming.runtime.transfer;
import com.google.common.base.FinalizablePhantomReference;
import com.google.common.base.FinalizableReferenceQueue;
@@ -41,6 +41,47 @@ public class ChannelId {
this.nativeIdPtr = nativeIdPtr;
}
/**
 * @return the {@code bytes} field itself, not a copy
 *     (NOTE(review): presumably the raw channel id bytes; confirm against the constructor)
 */
public byte[] getBytes() {
return bytes;
}
/**
 * @return the {@code buffer} field itself, not a duplicate
 *     (NOTE(review): presumably a buffer holding the id bytes; confirm against the constructor)
 */
public ByteBuffer getBuffer() {
return buffer;
}
/**
 * @return the {@code address} field
 *     (NOTE(review): presumably the memory address of the id buffer; confirm)
 */
public long getAddress() {
return address;
}
/**
 * @return the native id pointer
 * @throws IllegalStateException if the pointer is 0, i.e. no native id is available
 */
public long getNativeIdPtr() {
  if (nativeIdPtr != 0) {
    return nativeIdPtr;
  }
  throw new IllegalStateException("native ID not available");
}
/**
 * @return the {@code strId} field
 */
@Override
public String toString() {
return strId;
}
/**
 * Two channel ids are equal when they are of exactly the same class and their {@code strId}
 * values are equal.
 */
@Override
public boolean equals(Object o) {
  if (o == this) {
    return true;
  }
  boolean sameClass = o != null && o.getClass() == getClass();
  return sameClass && strId.equals(((ChannelId) o).strId);
}
/**
 * @return the hash code of {@code strId}, the same field {@code equals} compares
 */
@Override
public int hashCode() {
return strId.hashCode();
}
private static native long createNativeId(long idAddress);
private static native void destroyNativeId(long nativeIdPtr);
@@ -123,7 +164,7 @@ public class ChannelId {
* @param id hex string representation of channel id
* @return bytes representation of channel id
*/
public static byte[] idStrToBytes(String id) {
static byte[] idStrToBytes(String id) {
byte[] idBytes = BaseEncoding.base16().decode(id.toUpperCase());
assert idBytes.length == ChannelId.ID_LENGTH;
return idBytes;
@@ -133,51 +174,10 @@ public class ChannelId {
* @param id bytes representation of channel id
* @return hex string representation of channel id
*/
public static String idBytesToStr(byte[] id) {
static String idBytesToStr(byte[] id) {
assert id.length == ChannelId.ID_LENGTH;
return BaseEncoding.base16().encode(id).toLowerCase();
}
/**
 * @return the {@code bytes} field itself, not a copy
 *     (NOTE(review): presumably the raw channel id bytes; confirm against the constructor)
 */
public byte[] getBytes() {
return bytes;
}
/**
 * @return the {@code buffer} field itself, not a duplicate
 *     (NOTE(review): presumably a buffer holding the id bytes; confirm against the constructor)
 */
public ByteBuffer getBuffer() {
return buffer;
}
/**
 * @return the {@code address} field
 *     (NOTE(review): presumably the memory address of the id buffer; confirm)
 */
public long getAddress() {
return address;
}
/**
 * @return the native id pointer
 * @throws IllegalStateException if the pointer is 0, i.e. no native id is available
 */
public long getNativeIdPtr() {
  if (nativeIdPtr != 0) {
    return nativeIdPtr;
  }
  throw new IllegalStateException("native ID not available");
}
/**
 * @return the {@code strId} field
 */
@Override
public String toString() {
return strId;
}
/**
 * Two channel ids are equal when they are of exactly the same class and their {@code strId}
 * values are equal.
 */
@Override
public boolean equals(Object o) {
  if (o == this) {
    return true;
  }
  boolean sameClass = o != null && o.getClass() == getClass();
  return sameClass && strId.equals(((ChannelId) o).strId);
}
/**
 * @return the hash code of {@code strId}, the same field {@code equals} compares
 */
@Override
public int hashCode() {
return strId.hashCode();
}
}
@@ -1,4 +1,4 @@
package io.ray.streaming.runtime.transfer.channel;
package io.ray.streaming.runtime.transfer;
import io.ray.streaming.runtime.config.StreamingWorkerConfig;
import io.ray.streaming.runtime.generated.Streaming;
@@ -10,7 +10,7 @@ public class ChannelUtils {
private static final Logger LOGGER = LoggerFactory.getLogger(ChannelUtils.class);
public static byte[] toNativeConf(StreamingWorkerConfig workerConfig) {
static byte[] toNativeConf(StreamingWorkerConfig workerConfig) {
Streaming.StreamingConfig.Builder builder = Streaming.StreamingConfig.newBuilder();
// job name
@@ -0,0 +1,55 @@
package io.ray.streaming.runtime.transfer;
import java.nio.ByteBuffer;
/**
 * Immutable message carrying a data payload from an upstream to a downstream operator,
 * together with its message id, timestamp and the id of the channel it came from.
 */
public class DataMessage implements Message {

  private final ByteBuffer body;
  private final long msgId;
  private final long timestamp;
  private final String channelId;

  /**
   * @param body message payload
   * @param timestamp timestamp of the message
   * @param msgId message id
   * @param channelId string id of the channel the data is coming from
   */
  public DataMessage(ByteBuffer body, long timestamp, long msgId, String channelId) {
    this.channelId = channelId;
    this.msgId = msgId;
    this.timestamp = timestamp;
    this.body = body;
  }

  /**
   * @return the payload buffer passed to the constructor
   */
  @Override
  public ByteBuffer body() {
    return this.body;
  }

  /**
   * @return the timestamp passed to the constructor
   */
  @Override
  public long timestamp() {
    return this.timestamp;
  }

  /**
   * @return message id
   */
  public long msgId() {
    return this.msgId;
  }

  /**
   * @return string id of channel where data is coming from
   */
  public String channelId() {
    return this.channelId;
  }

  @Override
  public String toString() {
    StringBuilder text = new StringBuilder("DataMessage{");
    text.append("body=").append(body);
    text.append(", msgId=").append(msgId);
    text.append(", timestamp=").append(timestamp);
    text.append(", channelId='").append(channelId).append('\'');
    text.append('}');
    return text.toString();
  }
}
@@ -4,22 +4,11 @@ import com.google.common.base.Preconditions;
import io.ray.api.BaseActorHandle;
import io.ray.streaming.runtime.config.StreamingWorkerConfig;
import io.ray.streaming.runtime.config.types.TransferChannelType;
import io.ray.streaming.runtime.transfer.channel.ChannelId;
import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo;
import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo.ChannelCreationStatus;
import io.ray.streaming.runtime.transfer.channel.ChannelUtils;
import io.ray.streaming.runtime.transfer.channel.OffsetInfo;
import io.ray.streaming.runtime.transfer.message.BarrierMessage;
import io.ray.streaming.runtime.transfer.message.ChannelMessage;
import io.ray.streaming.runtime.transfer.message.DataMessage;
import io.ray.streaming.runtime.util.Platform;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -33,20 +22,7 @@ public class DataReader {
private static final Logger LOG = LoggerFactory.getLogger(DataReader.class);
private long nativeReaderPtr;
// params set by getBundleNative: bundle data address + size
private final ByteBuffer getBundleParams = ByteBuffer.allocateDirect(24);
// We use direct buffer to reduce gc overhead and memory copy.
private final ByteBuffer bundleData = Platform.wrapDirectBuffer(0, 0);
private final ByteBuffer bundleMeta = ByteBuffer.allocateDirect(BundleMeta.LENGTH);
private final Map<String, ChannelCreationStatus> queueCreationStatusMap = new HashMap<>();
private Queue<ChannelMessage> buf = new LinkedList<>();
{
getBundleParams.order(ByteOrder.nativeOrder());
bundleData.order(ByteOrder.nativeOrder());
bundleMeta.order(ByteOrder.nativeOrder());
}
private Queue<DataMessage> buf = new LinkedList<>();
/**
* @param inputChannels input channels ids
@@ -56,7 +32,6 @@ public class DataReader {
public DataReader(
List<String> inputChannels,
List<BaseActorHandle> fromActors,
Map<String, OffsetInfo> checkpoints,
StreamingWorkerConfig workerConfig) {
Preconditions.checkArgument(inputChannels.size() > 0);
Preconditions.checkArgument(inputChannels.size() == fromActors.size());
@@ -64,16 +39,11 @@ public class DataReader {
new ChannelCreationParametersBuilder().buildInputQueueParameters(inputChannels, fromActors);
byte[][] inputChannelsBytes = inputChannels.stream()
.map(ChannelId::idStrToBytes).toArray(byte[][]::new);
// get sequence ID and message ID from OffsetInfo
long[] seqIds = new long[inputChannels.size()];
long[] msgIds = new long[inputChannels.size()];
for (int i = 0; i < inputChannels.size(); i++) {
String channelId = inputChannels.get(i);
if (!checkpoints.containsKey(channelId)) {
msgIds[i] = 0;
continue;
}
msgIds[i] = checkpoints.get(inputChannels.get(i)).getStreamingMsgId();
seqIds[i] = 0;
msgIds[i] = 0;
}
long timerInterval = workerConfig.transferConfig.readerTimerIntervalMs();
TransferChannelType channelType = workerConfig.transferConfig.channelType();
@@ -81,34 +51,33 @@ public class DataReader {
if (TransferChannelType.MEMORY_CHANNEL == channelType) {
isMock = true;
}
boolean isRecreate = workerConfig.transferConfig.readerIsRecreate();
// create native reader
List<Integer> creationStatus = new ArrayList<>();
this.nativeReaderPtr = createDataReaderNative(
initialParameters,
inputChannelsBytes,
seqIds,
msgIds,
timerInterval,
creationStatus,
isRecreate,
ChannelUtils.toNativeConf(workerConfig),
isMock
);
for (int i = 0; i < inputChannels.size(); ++i) {
queueCreationStatusMap
.put(inputChannels.get(i), ChannelCreationStatus.fromInt(creationStatus.get(i)));
}
LOG.info("Create DataReader succeed for worker: {}, creation status={}.",
workerConfig.workerInternalConfig.workerName(), queueCreationStatusMap);
LOG.info("Create DataReader succeed for worker: {}.",
workerConfig.workerInternalConfig.workerName());
}
private static native long createDataReaderNative(
ChannelCreationParametersBuilder initialParameters,
byte[][] inputChannels,
long[] msgIds,
long timerInterval,
List<Integer> creationStatus,
byte[] configBytes,
boolean isMock);
// params set by getBundleNative: bundle data address + size
private final ByteBuffer getBundleParams = ByteBuffer.allocateDirect(24);
// We use direct buffer to reduce gc overhead and memory copy.
private final ByteBuffer bundleData = Platform.wrapDirectBuffer(0, 0);
private final ByteBuffer bundleMeta = ByteBuffer.allocateDirect(BundleMeta.LENGTH);
{
getBundleParams.order(ByteOrder.nativeOrder());
bundleData.order(ByteOrder.nativeOrder());
bundleMeta.order(ByteOrder.nativeOrder());
}
/**
* Read message from input channels, if timeout, return null.
@@ -116,21 +85,26 @@ public class DataReader {
* @param timeoutMillis timeout
* @return message or null
*/
public ChannelMessage read(long timeoutMillis) {
public DataMessage read(long timeoutMillis) {
if (buf.isEmpty()) {
getBundle(timeoutMillis);
// if bundle not empty. empty message still has data size + seqId + msgId
if (bundleData.position() < bundleData.limit()) {
BundleMeta bundleMeta = new BundleMeta(this.bundleMeta);
String channelID = bundleMeta.getChannelID();
long timestamp = bundleMeta.getBundleTs();
// barrier
if (bundleMeta.getBundleType() == DataBundleType.BARRIER) {
buf.offer(getBarrier(bundleData, channelID, timestamp));
throw new UnsupportedOperationException(
"Unsupported bundle type " + bundleMeta.getBundleType());
} else if (bundleMeta.getBundleType() == DataBundleType.BUNDLE) {
String channelID = bundleMeta.getChannelID();
long timestamp = bundleMeta.getBundleTs();
for (int i = 0; i < bundleMeta.getMessageListSize(); i++) {
buf.offer(getDataMessage(bundleData, channelID, timestamp));
}
} else if (bundleMeta.getBundleType() == DataBundleType.EMPTY) {
long messageId = bundleMeta.getLastMessageId();
buf.offer(new DataMessage(null, bundleMeta.getBundleTs(),
messageId, bundleMeta.getChannelID()));
}
}
}
@@ -140,31 +114,6 @@ public class DataReader {
return buf.poll();
}
/**
* Wraps the per-queue creation status map held by this reader into a new
* {@link ChannelRecoverInfo}.
*
* @return a new ChannelRecoverInfo built from queueCreationStatusMap
*/
public ChannelRecoverInfo getQueueRecoverInfo() {
return new ChannelRecoverInfo(queueCreationStatusMap);
}
/**
 * Reads ChannelId.ID_LENGTH bytes from the buffer's current position and converts them to the
 * string form of a channel id.
 *
 * @param buffer buffer positioned at the start of a channel id
 * @return channel id string
 */
private String getQueueIdString(ByteBuffer buffer) {
  final byte[] rawId = new byte[ChannelId.ID_LENGTH];
  buffer.get(rawId, 0, rawId.length);
  return ChannelId.idBytesToStr(rawId);
}
/**
 * Builds a barrier message from the current bundle.
 *
 * <p>The offsets are fetched from the native reader first, then the data message is read from
 * {@code bundleData}; both are combined through a {@code BarrierItem}.
 *
 * @param bundleData buffer holding the bundle payload
 * @param channelID id of the channel the bundle came from
 * @param timestamp bundle timestamp
 * @return the barrier message
 */
private BarrierMessage getBarrier(ByteBuffer bundleData, String channelID, long timestamp) {
  final byte[] rawOffsets = getOffsetsInfoNative(nativeReaderPtr);
  final ByteBuffer offsetsBuffer = ByteBuffer.wrap(rawOffsets);
  // the native side writes in native byte order
  offsetsBuffer.order(ByteOrder.nativeOrder());
  final BarrierOffsetInfo offsets = new BarrierOffsetInfo(offsetsBuffer);

  final DataMessage dataMessage = getDataMessage(bundleData, channelID, timestamp);
  final BarrierItem item = new BarrierItem(dataMessage, offsets);

  final long msgId = dataMessage.getMsgId();
  final long msgTimestamp = dataMessage.getTimestamp();
  final String msgChannelId = dataMessage.getChannelId();
  return new BarrierMessage(
      msgId,
      msgTimestamp,
      msgChannelId,
      item.getData(),
      item.getGlobalBarrierId(),
      item.getBarrierOffsetInfo().getQueueOffsetInfo());
}
private DataMessage getDataMessage(ByteBuffer bundleData, String channelID, long timestamp) {
int dataSize = bundleData.getInt();
// msgId
@@ -212,14 +161,22 @@ public class DataReader {
LOG.info("Finish closing DataReader.");
}
private static native long createDataReaderNative(
ChannelCreationParametersBuilder initialParameters,
byte[][] inputChannels,
long[] seqIds,
long[] msgIds,
long timerInterval,
boolean isRecreate,
byte[] configBytes,
boolean isMock);
private native void getBundleNative(
long nativeReaderPtr,
long timeoutMillis,
long params,
long metaAddress);
private native byte[] getOffsetsInfoNative(long nativeQueueConsumerPtr);
private native void stopReaderNative(long nativeReaderPtr);
private native void closeReaderNative(long nativeReaderPtr);
@@ -236,16 +193,7 @@ public class DataReader {
}
}
/**
 * Barrier kinds; each constant carries an integer code.
 */
public enum BarrierType {
  GLOBAL_BARRIER(0);

  // Enum state should never change after construction.
  private final int code;

  BarrierType(int code) {
    this.code = code;
  }
}
class BundleMeta {
static class BundleMeta {
// kMessageBundleHeaderSize + kUniqueIDSize:
// magicNum(4b) + bundleTs(8b) + lastMessageId(8b) + messageListSize(4b)
@@ -278,7 +226,13 @@ public class DataReader {
}
// rawBundleSize
rawBundleSize = buffer.getInt();
channelID = getQueueIdString(buffer);
channelID = getQidString(buffer);
}
/**
 * Reads ChannelId.ID_LENGTH bytes from the buffer's current position and converts them to the
 * string form of a channel id.
 *
 * @param buffer buffer positioned at the start of a channel id
 * @return channel id string
 */
private String getQidString(ByteBuffer buffer) {
  final byte[] rawId = new byte[ChannelId.ID_LENGTH];
  buffer.get(rawId, 0, rawId.length);
  return ChannelId.idBytesToStr(rawId);
}
public int getMagicNum() {
@@ -310,73 +264,4 @@ public class DataReader {
}
}
/**
 * Offsets attached to a barrier, deserialized from a buffer laid out as:
 * queue count (int), then for each queue its id bytes followed by a message id (long).
 */
class BarrierOffsetInfo {

  private int queueSize;
  private Map<String, OffsetInfo> queueOffsetInfo;

  /**
   * @param buffer buffer positioned at the serialized offset information
   */
  public BarrierOffsetInfo(ByteBuffer buffer) {
    this.queueSize = buffer.getInt();
    this.queueOffsetInfo = new HashMap<>(this.queueSize);
    for (int left = this.queueSize; left > 0; left--) {
      // the queue id must be consumed before the message id that follows it
      final String queueId = getQueueIdString(buffer);
      final OffsetInfo offset = new OffsetInfo(buffer.getLong());
      this.queueOffsetInfo.put(queueId, offset);
    }
  }

  public int getQueueSize() {
    return this.queueSize;
  }

  public Map<String, OffsetInfo> getQueueOffsetInfo() {
    return this.queueOffsetInfo;
  }
}
/**
 * A barrier decoded from a data message body laid out as:
 * barrier type (int), global barrier id (long), then the remaining bytes as barrier data.
 */
class BarrierItem {

  BarrierOffsetInfo barrierOffsetInfo;
  private long msgId;
  private BarrierType barrierType;
  private long globalBarrierId;
  private ByteBuffer data;

  /**
   * @param message data message whose body holds the serialized barrier
   * @param barrierOffsetInfo offsets that accompany the barrier
   */
  public BarrierItem(DataMessage message, BarrierOffsetInfo barrierOffsetInfo) {
    this.barrierOffsetInfo = barrierOffsetInfo;
    this.msgId = message.getMsgId();

    final ByteBuffer payload = message.body();
    // c++ use native order, so use native order here.
    payload.order(ByteOrder.nativeOrder());
    // The barrier type int is consumed but not interpreted; the type is fixed below.
    payload.getInt();
    this.globalBarrierId = payload.getLong();

    // Everything after the type and global barrier id is the barrier data.
    final ByteBuffer remainder = payload.slice();
    remainder.order(ByteOrder.nativeOrder());
    this.data = remainder;

    // Mark the message body as fully consumed.
    payload.position(payload.limit());
    this.barrierType = BarrierType.GLOBAL_BARRIER;
  }

  public long getBarrierMsgId() {
    return this.msgId;
  }

  public BarrierType getBarrierType() {
    return this.barrierType;
  }

  public long getGlobalBarrierId() {
    return this.globalBarrierId;
  }

  public ByteBuffer getData() {
    return this.data;
  }

  public BarrierOffsetInfo getBarrierOffsetInfo() {
    return this.barrierOffsetInfo;
  }
}
}
@@ -4,15 +4,10 @@ import com.google.common.base.Preconditions;
import io.ray.api.BaseActorHandle;
import io.ray.streaming.runtime.config.StreamingWorkerConfig;
import io.ray.streaming.runtime.config.types.TransferChannelType;
import io.ray.streaming.runtime.transfer.channel.ChannelId;
import io.ray.streaming.runtime.transfer.channel.ChannelUtils;
import io.ray.streaming.runtime.transfer.channel.OffsetInfo;
import io.ray.streaming.runtime.util.Platform;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -27,7 +22,6 @@ public class DataWriter {
private long nativeWriterPtr;
private ByteBuffer buffer = ByteBuffer.allocateDirect(0);
private long bufferAddress;
private List<String> outputChannels;
{
ensureBuffer(0);
@@ -37,33 +31,21 @@ public class DataWriter {
* @param outputChannels output channels ids
* @param toActors downstream output actors
* @param workerConfig configuration
* @param checkpoints offset of each channels
*/
public DataWriter(
List<String> outputChannels,
List<BaseActorHandle> toActors,
Map<String, OffsetInfo> checkpoints,
StreamingWorkerConfig workerConfig) {
Preconditions.checkArgument(!outputChannels.isEmpty());
Preconditions.checkArgument(outputChannels.size() == toActors.size());
this.outputChannels = outputChannels;
ChannelCreationParametersBuilder initialParameters =
new ChannelCreationParametersBuilder().buildOutputQueueParameters(outputChannels, toActors);
byte[][] outputChannelsBytes = outputChannels.stream()
.map(ChannelId::idStrToBytes).toArray(byte[][]::new);
long channelSize = workerConfig.transferConfig.channelSize();
// load message id from checkpoints
long[] msgIds = new long[outputChannels.size()];
for (int i = 0; i < outputChannels.size(); i++) {
String channelId = outputChannels.get(i);
if (!checkpoints.containsKey(channelId)) {
msgIds[i] = 0;
continue;
}
msgIds[i] = checkpoints.get(channelId).getStreamingMsgId();
msgIds[i] = 0;
}
TransferChannelType channelType = workerConfig.transferConfig.channelType();
boolean isMock = false;
@@ -82,14 +64,6 @@ public class DataWriter {
workerConfig.workerInternalConfig.workerName());
}
private static native long createWriterNative(
ChannelCreationParametersBuilder initialParameters,
byte[][] outputQueueIds,
long[] msgIds,
long channelSize,
byte[] confBytes,
boolean isMock);
/**
* Write msg into the specified channel
*
@@ -108,8 +82,9 @@ public class DataWriter {
* Write msg into the specified channels
*
* @param ids channel ids
* @param item message item data section is specified by [position, limit).
* item doesn't have to be a direct buffer.
* @param item message item; its data section is specified by [position, limit).
* item doesn't have to be a direct buffer.
*/
public void write(Set<ChannelId> ids, ByteBuffer item) {
int size = item.remaining();
@@ -129,27 +104,6 @@ public class DataWriter {
}
}
/**
 * Queries the native writer for its message ids and pairs them, by index, with the output
 * channels.
 *
 * @return map from output channel id to an OffsetInfo holding that channel's message id
 */
public Map<String, OffsetInfo> getOutputCheckpoints() {
  final long[] nativeMsgIds = getOutputMsgIdNative(nativeWriterPtr);
  final Map<String, OffsetInfo> checkpoints = new HashMap<>(outputChannels.size());
  int index = 0;
  for (String channel : outputChannels) {
    checkpoints.put(channel, new OffsetInfo(nativeMsgIds[index]));
    index++;
  }
  LOG.info("got output points, {}.", checkpoints);
  return checkpoints;
}
/**
 * Sends a barrier with the given checkpoint id and attachment through the native writer.
 *
 * @param checkpointId checkpoint id of the barrier
 * @param attach attachment; must use the native byte order
 */
public void broadcastBarrier(long checkpointId, ByteBuffer attach) {
  LOG.info("Broadcast barrier, cpId={}.", checkpointId);
  final boolean nativeOrdered = attach.order() == ByteOrder.nativeOrder();
  Preconditions.checkArgument(nativeOrdered);
  // NOTE(review): array() hands over the whole backing array regardless of position/limit and
  // throws for buffers without an accessible array (e.g. direct buffers) - confirm callers
  // only pass heap buffers that are meant to be sent in full.
  final byte[] payload = attach.array();
  broadcastBarrierNative(nativeWriterPtr, checkpointId, payload);
}
/**
* Logs the checkpoint id and forwards it to the native writer via clearCheckpointNative.
*
* @param checkpointId id of the checkpoint passed to the native call
*/
public void clearCheckpoint(long checkpointId) {
LOG.info("Producer clear checkpoint, checkpointId={}.", checkpointId);
clearCheckpointNative(nativeWriterPtr, checkpointId);
}
/**
* stop writer
*/
@@ -170,6 +124,14 @@ public class DataWriter {
LOG.info("Finish closing data writer.");
}
private static native long createWriterNative(
ChannelCreationParametersBuilder initialParameters,
byte[][] outputQueueIds,
long[] msgIds,
long channelSize,
byte[] confBytes,
boolean isMock);
private native long writeMessageNative(
long nativeQueueProducerPtr, long nativeIdPtr, long address, int size);
@@ -177,15 +139,4 @@ public class DataWriter {
private native void closeWriterNative(long nativeQueueProducerPtr);
private native long[] getOutputMsgIdNative(long nativeQueueProducerPtr);
private native void broadcastBarrierNative(
long nativeQueueProducerPtr, long checkpointId,
byte[] data);
private native void clearCheckpointNative(
long nativeQueueProducerPtr,
long checkpointId
);
}
@@ -0,0 +1,22 @@
package io.ray.streaming.runtime.transfer;
import java.nio.ByteBuffer;
/**
* A message that carries a body and a timestamp.
*/
public interface Message {
/**
* Message data.
* <p>
* The message body is a direct byte buffer, which may become invalid after the next call to
* <code>DataReader#getBundleNative</code>. Please consume this buffer fully
* before the next call to <code>getBundleNative</code>.
*
* @return message body
*/
ByteBuffer body();
/**
* @return timestamp when the item was written by the upstream DataWriter
*/
long timestamp();
}
@@ -1,60 +0,0 @@
package io.ray.streaming.runtime.transfer.channel;
import com.google.common.base.MoreObjects;
import java.io.Serializable;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Holds the creation status reported for each queue and exposes the queues that lost data.
 */
public class ChannelRecoverInfo implements Serializable {

  /** Creation status per queue id; values may be null when a status id was unknown. */
  public Map<String, ChannelCreationStatus> queueCreationStatusMap;

  /**
   * @param queueCreationStatusMap creation status per queue id
   */
  public ChannelRecoverInfo(Map<String, ChannelCreationStatus> queueCreationStatusMap) {
    this.queueCreationStatusMap = queueCreationStatusMap;
  }

  /**
   * @return ids of all queues whose creation status is {@link ChannelCreationStatus#DataLost}
   */
  public Set<String> getDataLostQueues() {
    Set<String> dataLostQueues = new HashSet<>();
    queueCreationStatusMap.forEach((queue, status) -> {
      // Identity comparison is null-safe: fromInt(int) yields null for unknown ids, and
      // calling equals() on such a value would throw a NullPointerException.
      if (status == ChannelCreationStatus.DataLost) {
        dataLostQueues.add(queue);
      }
    });
    return dataLostQueues;
  }

  /**
   * Renders as {@code <SimpleClassName>{dataLostQueues=[...]}}.
   */
  @Override
  public String toString() {
    return getClass().getSimpleName() + "{dataLostQueues=" + getDataLostQueues() + "}";
  }

  /**
   * Status of a channel after creation, identified by an integer id.
   */
  public enum ChannelCreationStatus {
    FreshStarted(0),
    PullOk(1),
    Timeout(2),
    DataLost(3);

    private final int id;

    ChannelCreationStatus(int id) {
      this.id = id;
    }

    /**
     * @param id integer id of the status
     * @return the matching status, or null when no status has this id
     */
    public static ChannelCreationStatus fromInt(int id) {
      for (ChannelCreationStatus status : ChannelCreationStatus.values()) {
        if (status.id == id) {
          return status;
        }
      }
      return null;
    }
  }
}
@@ -1,31 +0,0 @@
package io.ray.streaming.runtime.transfer.channel;
import com.google.common.base.MoreObjects;
import java.io.Serializable;
/**
* This data structure contains offset used by streaming queue.
*/
public class OffsetInfo implements Serializable {
private long streamingMsgId;
public OffsetInfo(long streamingMsgId) {
this.streamingMsgId = streamingMsgId;
}
public long getStreamingMsgId() {
return streamingMsgId;
}
public void setStreamingMsgId(long streamingMsgId) {
this.streamingMsgId = streamingMsgId;
}
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("streamingMsgId", streamingMsgId)
.toString();
}
}
@@ -1,22 +0,0 @@
package io.ray.streaming.runtime.transfer.exception;
import io.ray.streaming.runtime.transfer.DataReader;
import io.ray.streaming.runtime.transfer.DataWriter;
import io.ray.streaming.runtime.transfer.channel.ChannelId;
import java.nio.ByteBuffer;
/**
 * Signals that a read or write operation failed. It might be thrown from
 * {@link DataReader#read(long)} and {@link DataWriter#write(ChannelId, ByteBuffer)} when
 * {@link DataReader#stop()} or {@link DataWriter#stop()} is called.
 */
public class ChannelInterruptException extends RuntimeException {

  /** Creates the exception without a detail message. */
  public ChannelInterruptException() {
  }

  /**
   * @param message detail message
   */
  public ChannelInterruptException(String message) {
    super(message);
  }
}
@@ -1,34 +0,0 @@
package io.ray.streaming.runtime.transfer.message;
import io.ray.streaming.runtime.transfer.channel.OffsetInfo;
import java.nio.ByteBuffer;
import java.util.Map;
/**
 * A channel message that carries barrier data, a checkpoint id and input offsets.
 */
public class BarrierMessage extends ChannelMessage {

  private final long checkpointId;
  private final ByteBuffer data;
  private final Map<String, OffsetInfo> inputOffsets;

  /**
   * @param msgId message id
   * @param timestamp message timestamp
   * @param channelId id of the channel the message belongs to
   * @param data barrier data
   * @param checkpointId checkpoint id of the barrier
   * @param inputOffsets offsets keyed by channel id
   */
  public BarrierMessage(
      long msgId, long timestamp, String channelId,
      ByteBuffer data, long checkpointId, Map<String, OffsetInfo> inputOffsets) {
    super(msgId, timestamp, channelId);
    this.checkpointId = checkpointId;
    this.inputOffsets = inputOffsets;
    this.data = data;
  }

  public long getCheckpointId() {
    return this.checkpointId;
  }

  public ByteBuffer getData() {
    return this.data;
  }

  public Map<String, OffsetInfo> getInputOffsets() {
    return this.inputOffsets;
  }
}
@@ -1,26 +0,0 @@
package io.ray.streaming.runtime.transfer.message;
/**
 * Base type for messages read from a channel: message id, timestamp and channel id.
 */
public class ChannelMessage {

  private final String channelId;
  private final long timestamp;
  private final long msgId;

  /**
   * @param msgId message id
   * @param timestamp message timestamp
   * @param channelId id of the channel the message belongs to
   */
  public ChannelMessage(long msgId, long timestamp, String channelId) {
    this.channelId = channelId;
    this.timestamp = timestamp;
    this.msgId = msgId;
  }

  public String getChannelId() {
    return this.channelId;
  }

  public long getTimestamp() {
    return this.timestamp;
  }

  public long getMsgId() {
    return this.msgId;
  }
}
@@ -1,21 +0,0 @@
package io.ray.streaming.runtime.transfer.message;
import java.nio.ByteBuffer;
/**
 * A channel message carrying data exchanged between upstream and downstream operators.
 */
public class DataMessage extends ChannelMessage {

  private final ByteBuffer body;

  /**
   * @param body message payload
   * @param timestamp message timestamp
   * @param msgId message id
   * @param channelId id of the channel the message belongs to
   */
  public DataMessage(ByteBuffer body, long timestamp, long msgId, String channelId) {
    // the superclass takes the id first, then the timestamp
    super(msgId, timestamp, channelId);
    this.body = body;
  }

  /**
   * @return message payload
   */
  public ByteBuffer body() {
    return this.body;
  }
}
@@ -1,59 +0,0 @@
package io.ray.streaming.runtime.util;
import io.ray.streaming.runtime.context.ContextBackend;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Helpers that apply the exception policy for checkpoint state access:
 * reads fail loudly, writes are best-effort.
 */
public class CheckpointStateUtil {

  private static final Logger LOG = LoggerFactory.getLogger(CheckpointStateUtil.class);

  /**
   * Loads a checkpoint value. Exceptions are NOT tolerated when loading a checkpoint: any
   * failure is rethrown as a {@link CheckpointStateRuntimeException}.
   *
   * @param checkpointState state backend
   * @param cpKey checkpoint key
   * @return the value returned by the state backend for the key
   */
  public static byte[] get(ContextBackend checkpointState, String cpKey) {
    try {
      return checkpointState.get(cpKey);
    } catch (Exception e) {
      final String reason = String.format("Failed to get %s from state backend.", cpKey);
      throw new CheckpointStateRuntimeException(reason, e);
    }
  }

  /**
   * Saves a checkpoint value. Exceptions ARE tolerated when saving a checkpoint: a failure is
   * logged and otherwise ignored.
   *
   * @param checkpointState state backend
   * @param key checkpoint key
   * @param val checkpoint value
   */
  public static void put(ContextBackend checkpointState, String key, byte[] val) {
    try {
      checkpointState.put(key, val);
    } catch (Exception e) {
      // deliberate best-effort: report and carry on
      LOG.error("Failed to put key {} to state backend.", key, e);
    }
  }

  /**
   * Raised when a checkpoint value cannot be loaded from the state backend.
   */
  public static class CheckpointStateRuntimeException extends RuntimeException {

    public CheckpointStateRuntimeException() {
    }

    public CheckpointStateRuntimeException(String message) {
      super(message);
    }

    public CheckpointStateRuntimeException(String message, Throwable cause) {
      super(message, cause);
    }
  }
}
@@ -3,29 +3,13 @@ package io.ray.streaming.runtime.util;
import io.ray.runtime.RayNativeRuntime;
import io.ray.runtime.util.JniUtils;
import java.lang.management.ManagementFactory;
import java.net.InetAddress;
import java.net.UnknownHostException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class EnvUtil {
private static final Logger LOG = LoggerFactory.getLogger(EnvUtil.class);
/**
 * @return the part of the runtime MXBean name that precedes the first '@'
 */
public static String getJvmPid() {
  final String runtimeName = ManagementFactory.getRuntimeMXBean().getName();
  final String[] parts = runtimeName.split("@");
  return parts[0];
}
/**
 * @return the local host name, or an empty string when the local host cannot be resolved
 */
public static String getHostName() {
  try {
    return InetAddress.getLocalHost().getHostName();
  } catch (UnknownHostException e) {
    LOG.error("Error occurs while fetching local host.", e);
    return "";
  }
}
public static void loadNativeLibraries() {
// Explicitly load `RayNativeRuntime`, to make sure `core_worker_library_java`
// is loaded before `streaming_java`.
@@ -1,220 +0,0 @@
package io.ray.streaming.runtime.util;
import com.sun.management.OperatingSystemMXBean;
import io.ray.api.id.UniqueId;
import io.ray.streaming.runtime.core.resource.Container;
import io.ray.streaming.runtime.core.resource.ContainerId;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Resource utility that collects resource usage information of the current OS and JVM.
 */
public class ResourceUtil {

  public static final Logger LOG = LoggerFactory.getLogger(ResourceUtil.class);

  /**
   * Refer to: https://docs.oracle.com/javase/8/docs/jre/api/management/extension/com/sun/management/OperatingSystemMXBean.html
   */
  private static final OperatingSystemMXBean osmxb =
      (OperatingSystemMXBean) ManagementFactory.getOperatingSystemMXBean();

  /**
   * Log current jvm process's memory detail
   */
  public static void logProcessMemoryDetail() {
    // Skip building the message entirely when it would not be logged.
    if (!LOG.isInfoEnabled()) {
      return;
    }
    int mb = 1024 * 1024;
    //Getting the runtime reference from system
    Runtime runtime = Runtime.getRuntime();
    StringBuilder sb = new StringBuilder(32);
    sb.append("used memory: ").append((runtime.totalMemory() - runtime.freeMemory()) / mb)
        .append(", free memory: ").append(runtime.freeMemory() / mb)
        .append(", total memory: ").append(runtime.totalMemory() / mb)
        .append(", max memory: ").append(runtime.maxMemory() / mb);
    LOG.info(sb.toString());
  }

  /**
   * @return jvm heap usage ratio. note that one of the survivor space is not include in total
   *     memory while calculating this ratio.
   */
  public static double getJvmHeapUsageRatio() {
    Runtime runtime = Runtime.getRuntime();
    return (runtime.totalMemory() - runtime.freeMemory()) * 1.0 / runtime.maxMemory();
  }

  /**
   * @return jvm heap usage(in bytes).
   *     note that this value doesn't include one of the survivor space.
   */
  public static long getJvmHeapUsageInBytes() {
    Runtime runtime = Runtime.getRuntime();
    return runtime.totalMemory() - runtime.freeMemory();
  }

  /**
   * @return the total amount of physical memory in bytes.
   */
  public static long getSystemTotalMemory() {
    return osmxb.getTotalPhysicalMemorySize();
  }

  /**
   * @return the used system physical memory in bytes
   */
  public static long getSystemMemoryUsage() {
    long totalMemory = osmxb.getTotalPhysicalMemorySize();
    long freeMemory = osmxb.getFreePhysicalMemorySize();
    return totalMemory - freeMemory;
  }

  /**
   * @return the ratio of used system physical memory. This value is a double in the [0.0,1.0]
   */
  public static double getSystemMemoryUsageRatio() {
    double totalMemory = osmxb.getTotalPhysicalMemorySize();
    double freeMemory = osmxb.getFreePhysicalMemorySize();
    double ratio = freeMemory / totalMemory;
    return 1 - ratio;
  }

  /**
   * @return the cpu load for current jvm process. This value is a double in the [0.0,1.0]
   */
  public static double getProcessCpuUsage() {
    return osmxb.getProcessCpuLoad();
  }

  /**
   * @return the system cpu usage.
   *     This value is a double in the [0.0,1.0]
   *     We will try to use `vsar` to get cpu usage by default,
   *     and use MXBean if any exception raised.
   */
  public static double getSystemCpuUsage() {
    double cpuUsage = 0.0;
    try {
      cpuUsage = getSystemCpuUtilByVsar();
    } catch (Exception e) {
      // The vsar failure has already been logged by getSystemCpuUtilByVsar; fall back.
      cpuUsage = getSystemCpuUtilByMXBean();
    }
    return cpuUsage;
  }

  /**
   * Returns the "recent cpu usage" for the whole system. This value is a double in the [0.0,1.0]
   * interval. A value of 0.0 means that all CPUs were idle during the recent period of time
   * observed, while a value of 1.0 means that all CPUs were actively running 100% of the time
   * during the recent period being observed
   */
  public static double getSystemCpuUtilByMXBean() {
    return osmxb.getSystemCpuLoad();
  }

  /**
   * Get system cpu util by vsar.
   *
   * <p>Runs {@code vsar --check --cpu -s util} through {@code /bin/sh}, takes the first output
   * line, and parses the text after the first '=' as a percentage.
   *
   * @return the parsed value divided by 100
   * @throws Exception if the command cannot be run, prints nothing, or prints a first line
   *     that cannot be parsed; the failure is logged before it is rethrown
   */
  public static double getSystemCpuUtilByVsar() throws Exception {
    String[] vsarCpuCommand = {"/bin/sh", "-c", "vsar --check --cpu -s util"};
    try {
      Process proc = Runtime.getRuntime().exec(vsarCpuCommand);
      // Close the reader and release the child process even when parsing fails.
      try (BufferedReader br = new BufferedReader(
          new InputStreamReader(new BufferedInputStream(proc.getInputStream())))) {
        List<String> outputLines = new ArrayList<>();
        String line;
        while ((line = br.readLine()) != null) {
          outputLines.add(line);
        }
        if (outputLines.isEmpty()) {
          throw new IOException("Vsar check cpu usage failed, maybe vsar is not installed.");
        }
        String firstLine = outputLines.get(0);
        String[] split = firstLine.split("=");
        if (split.length < 2) {
          // Report malformed output explicitly instead of an index-out-of-bounds error.
          throw new IOException("Unexpected vsar output: " + firstLine);
        }
        return Double.parseDouble(split[1]) / 100.0D;
      } finally {
        proc.destroy();
      }
    } catch (Exception e) {
      LOG.warn("Failed to get cpu usage by vsar.", e);
      throw e;
    }
  }

  /**
   * @return the system load average for the last minute
   */
  public static double getSystemLoadAverage() {
    return osmxb.getSystemLoadAverage();
  }

  /**
   * @return system cpu cores num
   */
  public static int getCpuCores() {
    return osmxb.getAvailableProcessors();
  }

  /**
   * Get containers by hostname of address
   *
   * @param containers container list
   * @param containerHosts container hostname or address set
   * @return matched containers
   */
  public static List<Container> getContainersByHostname(
      List<Container> containers,
      Collection<String> containerHosts) {
    return containers.stream()
        .filter(container ->
            containerHosts.contains(container.getHostname()) ||
                containerHosts.contains(container.getAddress()))
        .collect(Collectors.toList());
  }

  /**
   * Get container by hostname
   *
   * @param containers container list
   * @param hostName container hostname (also matched against the container address)
   * @return the first matching container, if any
   */
  public static Optional<Container> getContainerByHostname(
      List<Container> containers,
      String hostName) {
    return containers.stream()
        .filter(container -> container.getHostname().equals(hostName) ||
            container.getAddress().equals(hostName))
        .findFirst();
  }

  /**
   * Get container by id
   *
   * @param containers container list
   * @param containerID container id
   * @return the first container with this id, if any
   */
  public static Optional<Container> getContainerById(
      List<Container> containers,
      ContainerId containerID) {
    return containers.stream()
        .filter(container -> container.getId().equals(containerID))
        .findFirst();
  }
}
@@ -1,15 +0,0 @@
package io.ray.streaming.runtime.util;
import io.ray.runtime.serializer.FstSerializer;
/**
 * Thin facade over {@code FstSerializer} for encoding and decoding objects.
 */
public class Serializer {

  /**
   * @param obj object to encode
   * @return the bytes produced by {@code FstSerializer.encode}
   */
  public static byte[] encode(Object obj) {
    final byte[] encoded = FstSerializer.encode(obj);
    return encoded;
  }

  /**
   * @param bytes bytes to decode
   * @param <T> type the caller expects
   * @return the object produced by {@code FstSerializer.decode}
   */
  public static <T> T decode(byte[] bytes) {
    final T decoded = FstSerializer.decode(bytes);
    return decoded;
  }
}
@@ -1,32 +1,20 @@
package io.ray.streaming.runtime.worker;
import io.ray.api.Ray;
import io.ray.streaming.runtime.config.StreamingWorkerConfig;
import io.ray.streaming.runtime.config.types.TransferChannelType;
import io.ray.streaming.runtime.context.ContextBackend;
import io.ray.streaming.runtime.context.ContextBackendFactory;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
import io.ray.streaming.runtime.core.processor.OneInputProcessor;
import io.ray.streaming.runtime.core.processor.ProcessBuilder;
import io.ray.streaming.runtime.core.processor.SourceProcessor;
import io.ray.streaming.runtime.core.processor.StreamProcessor;
import io.ray.streaming.runtime.master.JobMaster;
import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest;
import io.ray.streaming.runtime.message.CallResult;
import io.ray.streaming.runtime.rpc.RemoteCallMaster;
import io.ray.streaming.runtime.transfer.TransferHandler;
import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo;
import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo.ChannelCreationStatus;
import io.ray.streaming.runtime.util.CheckpointStateUtil;
import io.ray.streaming.runtime.util.EnvUtil;
import io.ray.streaming.runtime.util.Serializer;
import io.ray.streaming.runtime.worker.context.JobWorkerContext;
import io.ray.streaming.runtime.worker.tasks.OneInputStreamTask;
import io.ray.streaming.runtime.worker.tasks.SourceStreamTask;
import io.ray.streaming.runtime.worker.tasks.StreamTask;
import java.io.Serializable;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -48,223 +36,90 @@ public class JobWorker implements Serializable {
EnvUtil.loadNativeLibraries();
}
public final Object initialStateChangeLock = new Object();
/**
* isRecreate=true means this worker is initialized more than once after actor created.
*/
public AtomicBoolean isRecreate = new AtomicBoolean(false);
public ContextBackend contextBackend;
private JobWorkerContext workerContext;
private ExecutionVertex executionVertex;
private StreamingWorkerConfig workerConfig;
/**
* The while-loop thread to read message, process message, and write results
*/
private StreamTask task;
/**
* transferHandler handles messages by ray direct call
*/
private TransferHandler transferHandler;
/**
* A flag to avoid duplicated rollback. Becomes true after requesting
* rollback, set to false when finish rollback.
*/
private boolean isNeedRollback = false;
private int rollbackCount = 0;
/**
* Creates a job worker for the given execution vertex.
* <p>
* When the current actor was not restarted, the context is saved and construction ends.
* Otherwise the worker context is loaded from the context backend; if it is found the worker
* is initialized from it and a rollback is requested, and if it is missing an error is logged.
*
* @param executionVertex vertex that supplies the worker configuration
*/
public JobWorker(ExecutionVertex executionVertex) {
LOG.info("Creating job worker.");
// TODO: the following 3 lines is duplicated with that in init(), try to optimise it later.
this.executionVertex = executionVertex;
this.workerConfig = new StreamingWorkerConfig(executionVertex.getWorkerConfig());
this.contextBackend = ContextBackendFactory.getContextBackend(this.workerConfig);
LOG.info("Ray.getRuntimeContext().wasCurrentActorRestarted()={}",
Ray.getRuntimeContext().wasCurrentActorRestarted());
// Fresh start: persist the context and finish.
// NOTE(review): workerContext is not assigned anywhere in this constructor before
// saveContext() serializes it - confirm that saving an unset context here is intended.
if (!Ray.getRuntimeContext().wasCurrentActorRestarted()) {
saveContext();
LOG.info("Job worker is fresh started, init success.");
return;
}
// Restarted actor: try to restore the previously saved worker context.
LOG.info("Begin load job worker checkpoint state.");
byte[] bytes = CheckpointStateUtil.get(contextBackend, getJobWorkerContextKey());
if (bytes != null) {
JobWorkerContext context = Serializer.decode(bytes);
LOG.info("Worker recover from checkpoint state, byte len={}, context={}.", bytes.length,
context);
// Re-initialize from the restored context, then ask for a rollback.
init(context);
requestRollback("LoadCheckpoint request rollback in new actor.");
} else {
// No saved context was found; the worker is left uninitialized and only an error is logged.
LOG.error(
"Worker is reconstructed, but can't load checkpoint. " +
"Check whether you checkpoint state is reliable. Current checkpoint state is {}.",
contextBackend.getClass().getName());
}
}
public synchronized void saveContext() {
byte[] contextBytes = Serializer.encode(workerContext);
String key = getJobWorkerContextKey();
LOG.info("Saving context, worker context={}, serialized byte length={}, key={}.", workerContext,
contextBytes.length, key);
CheckpointStateUtil.put(contextBackend, key, contextBytes);
public JobWorker() {
LOG.info("Creating job worker succeeded.");
}
/**
 * Initializes this worker from the given context: records the context, its execution vertex and
 * the derived worker config, obtains the context backend, and finally saves the context.
 *
 * @param workerContext context to initialize from
 * @return always {@code true}
 */
public Boolean init(JobWorkerContext workerContext) {
  // IMPORTANT: some test cases rely on the log line below to find the workers' pid,
  // so be careful when changing it.
  LOG.info("Initiating job worker: {}. Worker context is: {}, pid={}.",
      workerContext.getWorkerName(), workerContext, EnvUtil.getJvmPid());

  this.workerContext = workerContext;
  final ExecutionVertex vertex = workerContext.getExecutionVertex();
  this.executionVertex = vertex;
  final StreamingWorkerConfig config = new StreamingWorkerConfig(vertex.getWorkerConfig());
  this.workerConfig = config;

  // init state backend
  this.contextBackend = ContextBackendFactory.getContextBackend(config);

  LOG.info("Initiating job worker succeeded: {}.", workerContext.getWorkerName());
  saveContext();
  return Boolean.TRUE;
}
/**
* Start worker's stream tasks with specific checkpoint ID.
*
* @return a {@link CallResult} with {@link ChannelRecoverInfo},
* contains {@link ChannelCreationStatus} of each input queue.
*/
public CallResult<ChannelRecoverInfo> rollback(Long checkpointId, Long startRollbackTs) {
synchronized (initialStateChangeLock) {
if (task != null && task.isAlive() && checkpointId == task.lastCheckpointId &&
task.isInitialState) {
return CallResult.skipped("Task is already in initial state, skip this rollback.");
}
}
long remoteCallCost = System.currentTimeMillis() - startRollbackTs;
LOG.info("Start rollback[{}], checkpoint is {}, remote call cost {}ms.",
executionVertex.getExecutionJobVertexName(), checkpointId, remoteCallCost);
rollbackCount++;
if (rollbackCount > 1) {
isRecreate.set(true);
}
LOG.info("Initiating job worker: {}. Worker context is: {}.",
workerContext.getWorkerName(), workerContext);
try {
this.workerContext = workerContext;
this.executionVertex = workerContext.getExecutionVertex();
this.workerConfig = new StreamingWorkerConfig(executionVertex.getWorkerConfig());
//Init transfer
TransferChannelType channelType = workerConfig.transferConfig.channelType();
if (TransferChannelType.NATIVE_CHANNEL == channelType) {
transferHandler = new TransferHandler();
}
if (task != null) {
// make sure the task is closed
task.close();
task = null;
}
// create stream task
task = createStreamTask(checkpointId);
ChannelRecoverInfo channelRecoverInfo = task.recover(isRecreate.get());
isNeedRollback = false;
LOG.info("Rollback job worker success, checkpoint is {}, channelRecoverInfo is {}.",
checkpointId, channelRecoverInfo);
return CallResult.success(channelRecoverInfo);
task = createStreamTask();
if (task == null) {
return false;
}
} catch (Exception e) {
LOG.error("Rollback job worker has exception.", e);
return CallResult.fail(ExceptionUtils.getStackTrace(e));
LOG.error("Failed to initiate job worker.", e);
return false;
}
LOG.info("Initiating job worker succeeded: {}.", workerContext.getWorkerName());
return true;
}
/**
 * Starts the worker's stream task.
 *
 * @return {@code true} when the task was started, {@code false} when starting it threw an
 *     exception (the exception is logged, not rethrown)
 */
public Boolean start() {
  boolean started;
  try {
    task.start();
    started = true;
  } catch (Exception e) {
    LOG.error("Start worker [{}] occur error.", executionVertex.getExecutionVertexName(), e);
    started = false;
  }
  return started;
}
/**
* Create tasks based on the processor corresponding of the operator.
*/
private StreamTask createStreamTask(long checkpointId) {
StreamTask task;
private StreamTask createStreamTask() {
StreamTask task = null;
StreamProcessor streamProcessor = ProcessBuilder
.buildProcessor(executionVertex.getStreamOperator());
LOG.debug("Stream processor created: {}.", streamProcessor);
if (streamProcessor instanceof SourceProcessor) {
task = new SourceStreamTask(streamProcessor, this, checkpointId);
} else if (streamProcessor instanceof OneInputProcessor) {
task = new OneInputStreamTask(streamProcessor, this, checkpointId);
} else {
throw new RuntimeException("Unsupported processor type:" + streamProcessor);
try {
if (streamProcessor instanceof SourceProcessor) {
task = new SourceStreamTask(getTaskId(), streamProcessor, this);
} else if (streamProcessor instanceof OneInputProcessor) {
task = new OneInputStreamTask(getTaskId(), streamProcessor, this);
} else {
throw new RuntimeException("Unsupported processor type:" + streamProcessor);
}
} catch (Exception e) {
LOG.info("Failed to create stream task.", e);
return task;
}
LOG.info("Stream task created: {}.", task);
return task;
}
// ----------------------------------------------------------------------
// Checkpoint
// ----------------------------------------------------------------------
/**
 * Triggers a checkpoint on this worker's task.
 *
 * @param barrierId id of the barrier to trigger
 * @return the task's own result, or {@code false} when no task exists
 */
public Boolean triggerCheckpoint(Long barrierId) {
  LOG.info("Receive trigger, barrierId is {}.", barrierId);
  if (task == null) {
    return false;
  }
  return task.triggerCheckpoint(barrierId);
}
/**
 * Forwards a checkpoint-timeout notification to the task, if there is one.
 *
 * @param checkpointId id of the checkpoint that timed out
 * @return always {@code true}
 */
public Boolean notifyCheckpointTimeout(Long checkpointId) {
  LOG.info("Notify checkpoint timeout, checkpoint id is {}.", checkpointId);
  final boolean hasTask = task != null;
  if (hasTask) {
    task.notifyCheckpointTimeout(checkpointId);
  }
  return Boolean.TRUE;
}
/**
 * Asks the task to drop expired checkpoint state and expired queue messages.
 *
 * <p>The state is only cleared when {@code expiredStateCpId} is positive; the queue messages
 * are cleared whenever a task exists.
 *
 * @param expiredStateCpId checkpoint id whose state has expired
 * @param expiredQueueCpId checkpoint id whose queue messages have expired
 * @return always {@code true}
 */
public Boolean clearExpiredCheckpoint(Long expiredStateCpId, Long expiredQueueCpId) {
  LOG.info("Clear expired checkpoint state, checkpoint id is {}; " +
          "Clear expired queue msg, checkpoint id is {}",
      expiredStateCpId, expiredQueueCpId);

  if (task == null) {
    return true;
  }
  if (expiredStateCpId > 0) {
    task.clearExpiredCpState(expiredStateCpId);
  }
  task.clearExpiredQueueMsg(expiredQueueCpId);
  return true;
}
// ----------------------------------------------------------------------
// Failover
// ----------------------------------------------------------------------
/**
 * Marks this worker as needing a rollback and asks the master to roll it back.
 *
 * <p>A failed request is only logged as a warning.
 *
 * @param exceptionMsg message describing why the rollback is requested
 */
public void requestRollback(String exceptionMsg) {
  LOG.info("Request rollback.");
  isNeedRollback = true;
  isRecreate.set(true);

  final WorkerRollbackRequest request = new WorkerRollbackRequest(
      workerContext.getWorkerActorId(),
      exceptionMsg,
      EnvUtil.getHostName(),
      EnvUtil.getJvmPid());
  if (RemoteCallMaster.requestJobWorkerRollback(workerContext.getMaster(), request)) {
    return;
  }
  LOG.warn("Job worker request rollback failed! exceptionMsg={}.", exceptionMsg);
}
public Boolean checkIfNeedRollback(Long startCallTs) {
// No save checkpoint in this query.
long remoteCallCost = System.currentTimeMillis() - startCallTs;
LOG.info("Finished checking if need to rollback with result: {}, rpc delay={}ms.",
isNeedRollback, remoteCallCost);
return isNeedRollback;
public int getTaskId() {
return executionVertex.getExecutionVertexId();
}
public StreamingWorkerConfig getWorkerConfig() {
@@ -283,19 +138,11 @@ public class JobWorker implements Serializable {
return task;
}
/**
 * Builds the key under which this worker's context is stored: the configured context
 * checkpoint prefix, the job name, an underscore and the execution vertex id.
 */
private String getJobWorkerContextKey() {
  final StringBuilder key = new StringBuilder();
  key.append(workerConfig.checkpointConfig.jobWorkerContextCpPrefixKey());
  key.append(workerConfig.commonConfig.jobName());
  key.append("_");
  key.append(executionVertex.getExecutionVertexId());
  return key.toString();
}
/**
* Used by upstream streaming queue to send data to this actor.
*
* @param buffer raw message bytes, passed unchanged to the transfer handler
*/
public void onReaderMessage(byte[] buffer) {
// NOTE(review): the message is forwarded once under a null guard and then again
// unconditionally, so the second call still fails when no handler is set. This looks like
// two revisions of the body merged together -- confirm which single call is intended.
if (transferHandler != null) {
transferHandler.onReaderMessage(buffer);
}
transferHandler.onReaderMessage(buffer);
}
/**
@@ -312,9 +159,7 @@ public class JobWorker implements Serializable {
* Used by downstream streaming queue to send data to this actor
*/
public void onWriterMessage(byte[] buffer) {
// NOTE(review): the message is forwarded once under a null guard and then again
// unconditionally, so the second call still fails when no handler is set. This looks like
// two revisions of the body merged together -- confirm which single call is intended.
if (transferHandler != null) {
transferHandler.onWriterMessage(buffer);
}
transferHandler.onWriterMessage(buffer);
}
/**
@@ -327,5 +172,4 @@ public class JobWorker implements Serializable {
}
return transferHandler.onWriterMessageSync(buffer);
}
}
@@ -3,9 +3,7 @@ package io.ray.streaming.runtime.worker.context;
import com.google.common.base.MoreObjects;
import com.google.protobuf.ByteString;
import io.ray.api.ActorHandle;
import io.ray.api.id.ActorId;
import io.ray.runtime.actor.NativeActorHandle;
import io.ray.streaming.runtime.config.global.CommonConfig;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
import io.ray.streaming.runtime.generated.RemoteCall;
import io.ray.streaming.runtime.master.JobMaster;
@@ -35,10 +33,6 @@ public class JobWorkerContext implements Serializable {
this.executionVertex = executionVertex;
}
/** Returns the worker actor id held by this context's execution vertex. */
public ActorId getWorkerActorId() {
  final ActorId actorId = executionVertex.getWorkerActorId();
  return actorId;
}
/** Returns the worker id, which is the id of this context's execution vertex. */
public int getWorkerId() {
  final int vertexId = executionVertex.getExecutionVertexId();
  return vertexId;
}
@@ -59,14 +53,6 @@ public class JobWorkerContext implements Serializable {
return executionVertex;
}
/** Returns the worker config map of this context's execution vertex. */
public Map<String, String> getConf() {
  final ExecutionVertex vertex = getExecutionVertex();
  return vertex.getWorkerConfig();
}
/** Returns the value stored under {@code CommonConfig.JOB_NAME} in the worker config. */
public String getJobName() {
  final Map<String, String> conf = getConf();
  return conf.get(CommonConfig.JOB_NAME);
}
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
@@ -2,31 +2,22 @@ package io.ray.streaming.runtime.worker.tasks;
import com.google.common.base.MoreObjects;
import io.ray.streaming.runtime.core.processor.Processor;
import io.ray.streaming.runtime.generated.RemoteCall;
import io.ray.streaming.runtime.serialization.CrossLangSerializer;
import io.ray.streaming.runtime.serialization.JavaSerializer;
import io.ray.streaming.runtime.serialization.Serializer;
import io.ray.streaming.runtime.transfer.channel.OffsetInfo;
import io.ray.streaming.runtime.transfer.exception.ChannelInterruptException;
import io.ray.streaming.runtime.transfer.message.BarrierMessage;
import io.ray.streaming.runtime.transfer.message.ChannelMessage;
import io.ray.streaming.runtime.transfer.message.DataMessage;
import io.ray.streaming.runtime.transfer.Message;
import io.ray.streaming.runtime.worker.JobWorker;
import java.util.Map;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public abstract class InputStreamTask extends StreamTask {
private static final Logger LOG = LoggerFactory.getLogger(InputStreamTask.class);
private volatile boolean running = true;
private volatile boolean stopped = false;
private long readTimeoutMillis;
private final io.ray.streaming.runtime.serialization.Serializer javaSerializer;
private final io.ray.streaming.runtime.serialization.Serializer crossLangSerializer;
private final long readTimeoutMillis;
public InputStreamTask(Processor processor, JobWorker jobWorker, long lastCheckpointId) {
super(processor, jobWorker, lastCheckpointId);
public InputStreamTask(int taskId, Processor processor, JobWorker jobWorker) {
super(taskId, processor, jobWorker);
readTimeoutMillis = jobWorker.getWorkerConfig().transferConfig.readerTimerIntervalMs();
javaSerializer = new JavaSerializer();
crossLangSerializer = new CrossLangSerializer();
@@ -38,64 +29,35 @@ public abstract class InputStreamTask extends StreamTask {
@Override
public void run() {
try {
while (running) {
ChannelMessage item;
// reader.read() will change the consumer state once it got an item. This lock is to
// ensure worker can get correct isInitialState value in exactly-once-mode's rollback.
synchronized (jobWorker.initialStateChangeLock) {
item = reader.read(readTimeoutMillis);
if (item != null) {
isInitialState = false;
} else {
continue;
}
while (running) {
Message item = reader.read(readTimeoutMillis);
if (item != null) {
byte[] bytes = new byte[item.body().remaining() - 1];
byte typeId = item.body().get();
item.body().get(bytes);
Object obj;
if (typeId == Serializer.JAVA_TYPE_ID) {
obj = javaSerializer.deserialize(bytes);
} else {
obj = crossLangSerializer.deserialize(bytes);
}
if (item instanceof DataMessage) {
DataMessage dataMessage = (DataMessage) item;
byte[] bytes = new byte[dataMessage.body().remaining() - 1];
byte typeId = dataMessage.body().get();
dataMessage.body().get(bytes);
Object obj;
if (typeId == Serializer.JAVA_TYPE_ID) {
obj = javaSerializer.deserialize(bytes);
} else {
obj = crossLangSerializer.deserialize(bytes);
}
processor.process(obj);
} else if (item instanceof BarrierMessage) {
final BarrierMessage queueBarrier = (BarrierMessage) item;
byte[] barrierData = new byte[queueBarrier.getData().remaining()];
queueBarrier.getData().get(barrierData);
RemoteCall.Barrier barrierPb = RemoteCall.Barrier.parseFrom(barrierData);
final long checkpointId = barrierPb.getId();
LOG.info("Start to do checkpoint {}, worker name is {}.", checkpointId,
jobWorker.getWorkerContext().getWorkerName());
final Map<String, OffsetInfo> inputPoints = queueBarrier.getInputOffsets();
doCheckpoint(checkpointId, inputPoints);
LOG.info("Do checkpoint {} success.", checkpointId);
}
}
} catch (Throwable throwable) {
if (throwable instanceof ChannelInterruptException ||
ExceptionUtils.getRootCause(throwable) instanceof ChannelInterruptException) {
LOG.info("queue has stopped.");
} else {
// error occurred, need to rollback
LOG.error("Last success checkpointId={}, now occur error.", lastCheckpointId, throwable);
requestRollback(ExceptionUtils.getStackTrace(throwable));
processor.process(obj);
}
}
LOG.info("Input stream task thread exit.");
stopped = true;
}
/**
 * Asks the read loop to stop and blocks until the stop has been observed.
 *
 * <p>Clears {@code running} and then waits until {@code stopped} is seen as {@code true}.
 * Each iteration of the wait yields to the scheduler instead of spinning in an empty loop,
 * so the waiting thread gives other threads a chance to run while it polls the flag.
 *
 * @throws Exception declared by the overridden method; this implementation does not throw
 */
@Override
protected void cancelTask() throws Exception {
  running = false;
  while (!stopped) {
    // Nothing useful to do until the flag flips; hint the scheduler to run someone else.
    Thread.yield();
  }
}
/** Renders this task with its {@code taskId} and {@code processor}. */
@Override
public String toString() {
  final MoreObjects.ToStringHelper helper = MoreObjects.toStringHelper(this);
  helper.add("taskId", taskId);
  helper.add("processor", processor);
  return helper.toString();
}
@@ -8,7 +8,7 @@ import io.ray.streaming.runtime.worker.JobWorker;
*/
public class OneInputStreamTask extends InputStreamTask {
public OneInputStreamTask(Processor inputProcessor, JobWorker jobWorker, long lastCheckpointId) {
super(inputProcessor, jobWorker, lastCheckpointId);
public OneInputStreamTask(int taskId, Processor inputProcessor, JobWorker jobWorker) {
super(taskId, inputProcessor, jobWorker);
}
}
@@ -3,10 +3,7 @@ package io.ray.streaming.runtime.worker.tasks;
import io.ray.streaming.operator.SourceOperator;
import io.ray.streaming.runtime.core.processor.Processor;
import io.ray.streaming.runtime.core.processor.SourceProcessor;
import io.ray.streaming.runtime.transfer.exception.ChannelInterruptException;
import io.ray.streaming.runtime.worker.JobWorker;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -16,19 +13,12 @@ public class SourceStreamTask extends StreamTask {
private final SourceProcessor sourceProcessor;
/**
* The pending barrier ID to be triggered.
*/
private final AtomicReference<Long> pendingBarrier = new AtomicReference<>();
private long lastCheckpointId = 0;
/**
* SourceStreamTask for executing a {@link SourceOperator}. It is responsible for running the
* corresponding source operator.
*/
public SourceStreamTask(Processor sourceProcessor, JobWorker jobWorker, long lastCheckpointId) {
super(sourceProcessor, jobWorker, lastCheckpointId);
public SourceStreamTask(int taskId, Processor sourceProcessor, JobWorker jobWorker) {
super(taskId, sourceProcessor, jobWorker);
this.sourceProcessor = (SourceProcessor) processor;
}
@@ -39,48 +29,12 @@ public class SourceStreamTask extends StreamTask {
@Override
public void run() {
LOG.info("Source stream task thread start.");
Long barrierId;
try {
while (running) {
isInitialState = false;
// check checkpoint
barrierId = pendingBarrier.get();
if (barrierId != null) {
// Important: because cp maybe timeout, master will use the old checkpoint id again
if (pendingBarrier.compareAndSet(barrierId, null)) {
// source fetcher only have outputPoints
LOG.info("Start to do checkpoint {}, worker name is {}.",
barrierId, jobWorker.getWorkerContext().getWorkerName());
doCheckpoint(barrierId, null);
LOG.info("Finish to do checkpoint {}.", barrierId);
} else {
// pendingCheckpointId has modify, should not happen
LOG.warn("Pending checkpointId modify unexpected, expect={}, now={}.", barrierId,
pendingBarrier.get());
}
}
sourceProcessor.fetch();
}
} catch (Throwable e) {
if (e instanceof ChannelInterruptException ||
ExceptionUtils.getRootCause(e) instanceof ChannelInterruptException) {
LOG.info("queue has stopped.");
} else {
// occur error, need to rollback
LOG.error("Last success checkpointId={}, now occur error.", lastCheckpointId, e);
requestRollback(ExceptionUtils.getStackTrace(e));
}
}
LOG.info("Source stream task thread exit.");
sourceProcessor.run();
}
@Override
public boolean triggerCheckpoint(Long barrierId) {
return pendingBarrier.compareAndSet(null, barrierId);
protected void cancelTask() {
}
}
@@ -6,103 +6,53 @@ import io.ray.streaming.api.collector.Collector;
import io.ray.streaming.api.context.RuntimeContext;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.runtime.config.worker.WorkerInternalConfig;
import io.ray.streaming.runtime.context.ContextBackend;
import io.ray.streaming.runtime.context.OperatorCheckpointInfo;
import io.ray.streaming.runtime.core.collector.OutputCollector;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionEdge;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobVertex;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
import io.ray.streaming.runtime.core.processor.Processor;
import io.ray.streaming.runtime.generated.RemoteCall;
import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport;
import io.ray.streaming.runtime.rpc.RemoteCallMaster;
import io.ray.streaming.runtime.transfer.ChannelId;
import io.ray.streaming.runtime.transfer.DataReader;
import io.ray.streaming.runtime.transfer.DataWriter;
import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo;
import io.ray.streaming.runtime.transfer.channel.OffsetInfo;
import io.ray.streaming.runtime.util.CheckpointStateUtil;
import io.ray.streaming.runtime.util.Serializer;
import io.ray.streaming.runtime.worker.JobWorker;
import io.ray.streaming.runtime.worker.context.JobWorkerContext;
import io.ray.streaming.runtime.worker.context.StreamingRuntimeContext;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* {@link StreamTask} is a while-loop thread to read message, process message, and send result
* messages to downstream operators
*/
public abstract class StreamTask implements Runnable {
private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class);
private final ContextBackend checkpointState;
public volatile boolean isInitialState = true;
public long lastCheckpointId;
protected int taskId;
protected Processor processor;
protected JobWorker jobWorker;
protected DataReader reader;
protected DataWriter writer;
List<Collector> collectors = new ArrayList<>();
protected volatile boolean running = true;
protected volatile boolean stopped = false;
List<Collector> collectors = new ArrayList<>();
private Set<Long> outdatedCheckpoints = new HashSet<>();
private Thread thread;
protected StreamTask(Processor processor, JobWorker jobWorker, long lastCheckpointId) {
protected StreamTask(int taskId, Processor processor, JobWorker jobWorker) {
this.taskId = taskId;
this.processor = processor;
this.jobWorker = jobWorker;
this.checkpointState = jobWorker.contextBackend;
this.lastCheckpointId = lastCheckpointId;
prepareTask();
this.thread = new Thread(Ray.wrapRunnable(this),
this.getClass().getName() + "-" + System.currentTimeMillis());
this.thread.setDaemon(true);
}
/**
 * Prepares this task and starts its runner thread.
 *
 * @param isRecover passed on to {@code prepareTask}; also selects which log messages are written
 * @return the reader's queue recover info, or a recover info built from an empty map when this
 *     task has no reader
 */
public ChannelRecoverInfo recover(boolean isRecover) {
  LOG.info(isRecover ? "Stream task begin recover." : "Stream task first start begin.");
  prepareTask(isRecover);

  // start runner
  ChannelRecoverInfo result = new ChannelRecoverInfo(new HashMap<>());
  if (reader != null) {
    result = reader.getQueueRecoverInfo();
  }
  thread.setUncaughtExceptionHandler(
      (runner, error) -> LOG.error("Uncaught exception in runner thread.", error));
  LOG.info("Start stream task: {}.", this.getClass().getSimpleName());
  thread.start();

  LOG.info(isRecover ? "Stream task recover end." : "Stream task first start finished.");
  return result;
}
/**
* Load checkpoint and build upstream and downstream data transmission
* channels according to {@link ExecutionVertex}.
* Build upstream and downstream data transmission channels according to {@link ExecutionVertex}.
*/
private void prepareTask(boolean isRecreate) {
LOG.info("Preparing stream task, isRecreate={}.", isRecreate);
private void prepareTask() {
LOG.debug("Preparing stream task.");
ExecutionVertex executionVertex = jobWorker.getExecutionVertex();
// set vertex info into config for native using
@@ -111,92 +61,73 @@ public abstract class StreamTask implements Runnable {
jobWorker.getWorkerConfig().workerInternalConfig.setProperty(
WorkerInternalConfig.OP_NAME_INTERNAL, executionVertex.getExecutionJobVertexName());
OperatorCheckpointInfo operatorCheckpointInfo = new OperatorCheckpointInfo();
byte[] bytes = null;
// producer
// Fetch checkpoint from storage only in recreate mode not for new startup worker
// in rescaling or something like that.
if (isRecreate) {
String cpKey = genOpCheckpointKey(lastCheckpointId);
LOG.info("Getting task checkpoints from state, cpKey={}, checkpointId={}.", cpKey,
lastCheckpointId);
bytes = CheckpointStateUtil.get(checkpointState, cpKey);
if (bytes == null) {
String msg = String.format("Task recover failed, checkpoint is null! cpKey=%s", cpKey);
throw new RuntimeException(msg);
}
}
// when use memory state, if actor throw exception, will miss state
if (bytes != null) {
operatorCheckpointInfo = Serializer.decode(bytes);
processor.loadCheckpoint(operatorCheckpointInfo.processorCheckpoint);
LOG.info(
"Stream task recover from checkpoint state, checkpoint bytes len={}, checkpointInfo={}.",
bytes.length, operatorCheckpointInfo);
}
// writer
if (!executionVertex.getOutputEdges().isEmpty()) {
LOG.info("Register queue writer, channels={}, outputCheckpoints={}.",
executionVertex.getOutputChannelIdList(), operatorCheckpointInfo.outputPoints);
writer = new DataWriter(
executionVertex.getOutputChannelIdList(),
executionVertex.getOutputActorList(),
operatorCheckpointInfo.outputPoints,
jobWorker.getWorkerConfig()
);
}
// reader
if (!executionVertex.getInputEdges().isEmpty()) {
LOG.info("Register queue reader, channels={}, inputCheckpoints={}.",
executionVertex.getInputChannelIdList(), operatorCheckpointInfo.inputPoints);
reader = new DataReader(
executionVertex.getInputChannelIdList(),
executionVertex.getInputActorList(),
operatorCheckpointInfo.inputPoints,
jobWorker.getWorkerConfig()
);
}
openProcessor();
LOG.debug("Finished preparing stream task.");
}
/**
* Create one collector for each distinct output operator(i.e. each {@link ExecutionJobVertex})
*/
private void openProcessor() {
ExecutionVertex executionVertex = jobWorker.getExecutionVertex();
List<ExecutionEdge> outputEdges = executionVertex.getOutputEdges();
Map<String, List<String>> opGroupedChannelId = new HashMap<>();
Map<String, List<BaseActorHandle>> opGroupedActor = new HashMap<>();
Map<String, Partition> opPartitionMap = new HashMap<>();
for (int i = 0; i < outputEdges.size(); ++i) {
ExecutionEdge edge = outputEdges.get(i);
String opName = edge.getTargetExecutionJobVertexName();
if (!opPartitionMap.containsKey(opName)) {
opGroupedChannelId.put(opName, new ArrayList<>());
opGroupedActor.put(opName, new ArrayList<>());
}
opGroupedChannelId.get(opName).add(executionVertex.getOutputChannelIdList().get(i));
opGroupedActor.get(opName).add(executionVertex.getOutputActorList().get(i));
opPartitionMap.put(opName, edge.getPartition());
// merge all output edges to create writer
List<String> outputChannelIds = new ArrayList<>();
List<BaseActorHandle> targetActors = new ArrayList<>();
for (ExecutionEdge edge : outputEdges) {
String channelId = ChannelId.genIdStr(
taskId,
edge.getTargetExecutionVertex().getExecutionVertexId(),
executionVertex.getBuildTime());
outputChannelIds.add(channelId);
targetActors.add(edge.getTargetExecutionVertex().getWorkerActor());
}
if (!targetActors.isEmpty()) {
DataWriter writer = new DataWriter(
outputChannelIds, targetActors, jobWorker.getWorkerConfig()
);
// create a collector for each output operator
Map<String, List<String>> opGroupedChannelId = new HashMap<>();
Map<String, List<BaseActorHandle>> opGroupedActor = new HashMap<>();
Map<String, Partition> opPartitionMap = new HashMap<>();
for (int i = 0; i < outputEdges.size(); ++i) {
ExecutionEdge edge = outputEdges.get(i);
String opName = edge.getTargetExecutionJobVertexName();
if (!opPartitionMap.containsKey(opName)) {
opGroupedChannelId.put(opName, new ArrayList<>());
opGroupedActor.put(opName, new ArrayList<>());
}
opGroupedChannelId.get(opName).add(outputChannelIds.get(i));
opGroupedActor.get(opName).add(targetActors.get(i));
opPartitionMap.put(opName, edge.getPartition());
}
opPartitionMap.keySet().forEach(opName -> {
collectors.add(new OutputCollector(
writer, opGroupedChannelId.get(opName),
opGroupedActor.get(opName), opPartitionMap.get(opName)
));
});
}
// consumer
List<ExecutionEdge> inputEdges = executionVertex.getInputEdges();
List<String> inputChannelIds = new ArrayList<>();
List<BaseActorHandle> inputActors = new ArrayList<>();
for (ExecutionEdge edge : inputEdges) {
String queueName = ChannelId.genIdStr(
edge.getSourceExecutionVertex().getExecutionVertexId(),
taskId,
executionVertex.getBuildTime());
inputChannelIds.add(queueName);
inputActors.add(edge.getSourceExecutionVertex().getWorkerActor());
}
if (!inputActors.isEmpty()) {
LOG.info("Register queue consumer, channels {}.", inputChannelIds);
reader = new DataReader(inputChannelIds, inputActors, jobWorker.getWorkerConfig());
}
opPartitionMap.keySet().forEach(opName -> {
collectors.add(new OutputCollector(
writer, opGroupedChannelId.get(opName),
opGroupedActor.get(opName), opPartitionMap.get(opName)
));
});
RuntimeContext runtimeContext = new StreamingRuntimeContext(executionVertex,
jobWorker.getWorkerConfig().configMap, executionVertex.getParallelism());
processor.open(collectors, runtimeContext);
LOG.debug("Finished preparing stream task.");
}
/**
@@ -204,6 +135,16 @@ public abstract class StreamTask implements Runnable {
*/
protected abstract void init() throws Exception;
/**
* Stop running tasks.
*/
protected abstract void cancelTask() throws Exception;
/** Logs the task's simple class name and id, then starts the task thread. */
public void start() {
  final String taskType = this.getClass().getSimpleName();
  LOG.info("Start stream task: {}-{}", taskType, taskId);
  thread.start();
}
/**
* Close running tasks.
*/
@@ -218,134 +159,4 @@ public abstract class StreamTask implements Runnable {
LOG.info("Stream task close success.");
}
// ----------------------------------------------------------------------
// Checkpoint
// ----------------------------------------------------------------------
/**
 * Base implementation that rejects the call.
 *
 * @param barrierId id of the barrier to trigger (unused here)
 * @return never returns normally
 * @throws UnsupportedOperationException always
 */
public boolean triggerCheckpoint(Long barrierId) {
  final String reason = "Only source operator supports trigger checkpoints.";
  throw new UnsupportedOperationException(reason);
}
/**
 * Performs a checkpoint for this task.
 *
 * <p>When a writer exists, its output checkpoints are captured and a barrier carrying the
 * checkpoint id is broadcast downstream. The checkpoint id is then recorded as the last
 * checkpoint, the processor checkpoint is taken, and the combined info is saved and reported.
 * Exceptions raised while saving or reporting are logged and swallowed.
 *
 * @param checkpointId id of the checkpoint being taken
 * @param inputPoints input offsets to store with the checkpoint; may be {@code null}
 */
public void doCheckpoint(long checkpointId, Map<String, OffsetInfo> inputPoints) {
  final Map<String, OffsetInfo> outputPoints =
      (writer == null) ? null : captureOutputsAndBroadcastBarrier(checkpointId);

  LOG.info("Start do checkpoint, cp id={}, inputPoints={}, outputPoints={}.", checkpointId,
      inputPoints, outputPoints);

  this.lastCheckpointId = checkpointId;
  final Serializable processorState = processor.saveCheckpoint();

  try {
    saveCpStateAndReport(
        new OperatorCheckpointInfo(inputPoints, outputPoints, processorState, checkpointId),
        checkpointId);
  } catch (Exception e) {
    // there will be exceptions when flush state to backend.
    // we ignore the exception to prevent failover
    LOG.error("Processor or op checkpoint exception.", e);
  }

  LOG.info("Operator do checkpoint {} finish.", checkpointId);
}

/** Reads the writer's output checkpoints, then broadcasts a barrier for {@code checkpointId}. */
private Map<String, OffsetInfo> captureOutputsAndBroadcastBarrier(long checkpointId) {
  final Map<String, OffsetInfo> captured = writer.getOutputCheckpoints();
  final RemoteCall.Barrier barrier =
      RemoteCall.Barrier.newBuilder().setId(checkpointId).build();
  final ByteBuffer payload = ByteBuffer.wrap(barrier.toByteArray());
  payload.order(ByteOrder.nativeOrder());
  writer.broadcastBarrier(checkpointId, payload);
  return captured;
}
/**
 * Saves the operator checkpoint and then reports the commit for it.
 *
 * @param checkpointInfo checkpoint info to save
 * @param checkpointId id of the checkpoint being saved and reported
 */
private void saveCpStateAndReport(OperatorCheckpointInfo checkpointInfo, long checkpointId) {
  this.saveCp(checkpointInfo, checkpointId);
  this.reportCommit(checkpointId);
  LOG.info("Finish save cp state and report, checkpoint id is {}.", checkpointId);
}
/**
 * Serializes the checkpoint info and writes it to the checkpoint state under this task's
 * checkpoint key, unless the checkpoint id was marked outdated, in which case the mark is
 * removed and nothing is written.
 *
 * @param operatorCheckpointInfo checkpoint info to store
 * @param checkpointId id of the checkpoint being stored
 */
private void saveCp(OperatorCheckpointInfo operatorCheckpointInfo, long checkpointId) {
  final byte[] encoded = Serializer.encode(operatorCheckpointInfo);
  final String stateKey = genOpCheckpointKey(checkpointId);
  LOG.info("Saving task checkpoint, cpKey={}, byte len={}, checkpointInfo={}.", stateKey,
      encoded.length, operatorCheckpointInfo);

  synchronized (checkpointState) {
    // remove() reports whether the id was marked outdated, and clears the mark in one step.
    if (outdatedCheckpoints.remove(checkpointId)) {
      LOG.info("Outdated checkpoint, skip save checkpoint.");
      return;
    }
    CheckpointStateUtil.put(checkpointState, stateKey, encoded);
  }
}
/**
 * Asynchronously reports to the master that this worker committed the given checkpoint.
 *
 * @param checkpointId id of the committed checkpoint
 */
private void reportCommit(long checkpointId) {
  final JobWorkerContext workerContext = jobWorker.getWorkerContext();
  LOG.info("Report commit async, checkpoint id {}.", checkpointId);

  final WorkerCommitReport report =
      new WorkerCommitReport(workerContext.getWorkerActorId(), checkpointId);
  RemoteCallMaster.reportJobWorkerCommitAsync(workerContext.getMaster(), report);
}
/**
 * Handles a timed-out checkpoint: removes its stored state when it exists, otherwise marks
 * the checkpoint id as outdated. Failures are logged and not rethrown.
 *
 * @param checkpointId id of the checkpoint that timed out
 */
public void notifyCheckpointTimeout(long checkpointId) {
  final String stateKey = genOpCheckpointKey(checkpointId);
  try {
    synchronized (checkpointState) {
      if (!checkpointState.exists(stateKey)) {
        outdatedCheckpoints.add(checkpointId);
        return;
      }
      checkpointState.remove(stateKey);
    }
  } catch (Exception e) {
    LOG.error("Notify checkpoint timeout failed, checkpointId is {}.", checkpointId, e);
  }
}
/**
 * Removes the stored state of the given checkpoint; a failure is logged and not rethrown.
 *
 * @param checkpointId id of the expired checkpoint
 */
public void clearExpiredCpState(long checkpointId) {
  final String expiredKey = genOpCheckpointKey(checkpointId);
  try {
    this.checkpointState.remove(expiredKey);
  } catch (Exception removeFailure) {
    LOG.error("Failed to remove key {} from state backend.", expiredKey, removeFailure);
  }
}
/**
 * Clears the writer's checkpoint for the checkpoint id stored in the state backend under the
 * given checkpoint's key. Does nothing when the state cannot be read, no state is stored, or
 * there is no writer.
 *
 * @param checkpointId id of the expired checkpoint
 */
public void clearExpiredQueueMsg(long checkpointId) {
  // get operator checkpoint
  final String stateKey = genOpCheckpointKey(checkpointId);
  final byte[] stored;
  try {
    stored = checkpointState.get(stateKey);
  } catch (Exception e) {
    LOG.error("Failed to get key {} from state backend.", stateKey, e);
    return;
  }
  if (stored == null) {
    return;
  }

  final OperatorCheckpointInfo storedInfo = Serializer.decode(stored);
  final long storedCheckpointId = storedInfo.checkpointId;
  if (writer != null) {
    writer.clearCheckpoint(storedCheckpointId);
  }
}
/**
 * Builds the state key of an operator checkpoint: the configured operator checkpoint prefix,
 * followed by the job name, worker name and checkpoint id joined with underscores.
 *
 * @param checkpointId id of the checkpoint
 * @return the key for that checkpoint
 */
public String genOpCheckpointKey(long checkpointId) {
  // TODO: need to support job restart and actorId changed
  final JobWorkerContext workerContext = jobWorker.getWorkerContext();
  final StringBuilder key = new StringBuilder();
  key.append(jobWorker.getWorkerConfig().checkpointConfig.jobWorkerOpCpPrefixKey());
  key.append(workerContext.getJobName());
  key.append("_");
  key.append(workerContext.getWorkerName());
  key.append("_");
  key.append(checkpointId);
  return key.toString();
}
// ----------------------------------------------------------------------
// Failover
// ----------------------------------------------------------------------
/**
 * Delegates a rollback request to the owning job worker.
 *
 * @param exceptionMsg message describing why the rollback is requested
 */
protected void requestRollback(String exceptionMsg) {
  this.jobWorker.requestRollback(exceptionMsg);
}
/** Returns whether this task's thread is alive. */
public boolean isAlive() {
  final boolean alive = thread.isAlive();
  return alive;
}
}
@@ -10,12 +10,12 @@ import io.ray.streaming.runtime.worker.JobWorker;
public class TwoInputStreamTask extends InputStreamTask {
public TwoInputStreamTask(
int taskId,
Processor processor,
JobWorker jobWorker,
String leftStream,
String rightStream,
long lastCheckpointId) {
super(processor, jobWorker, lastCheckpointId);
String rightStream) {
super(taskId, processor, jobWorker);
((TwoInputProcessor) (super.processor)).setLeftStream(leftStream);
((TwoInputProcessor) (super.processor)).setRightStream(rightStream);
}
@@ -15,7 +15,7 @@ import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobVertex;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
import io.ray.streaming.runtime.core.resource.ResourceType;
import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext;
import io.ray.streaming.runtime.master.JobRuntimeContext;
import io.ray.streaming.runtime.master.graphmanager.GraphManager;
import io.ray.streaming.runtime.master.graphmanager.GraphManagerImpl;
import java.util.HashMap;
@@ -34,7 +34,7 @@ public class ExecutionGraphTest extends BaseUnitTest {
public void testBuildExecutionGraph() {
Map<String, String> jobConf = new HashMap<>();
StreamingConfig streamingConfig = new StreamingConfig(jobConf);
GraphManager graphManager = new GraphManagerImpl(new JobMasterRuntimeContext(streamingConfig));
GraphManager graphManager = new GraphManagerImpl(new JobRuntimeContext(streamingConfig));
JobGraph jobGraph = buildJobGraph();
jobGraph.getJobConfig().put("streaming.task.resource.cpu.limitation.enable", "true");
@@ -36,7 +36,7 @@ public class UnionStreamTest {
streamSource1
.union(streamSource2, streamSource3)
.sink((SinkFunction<Integer>) value -> {
LOG.info("UnionStreamTest, sink: {}", value);
LOG.info("UnionStreamTest: {}", value);
try {
if (!Files.exists(Paths.get(sinkFileName))) {
Files.createFile(Paths.get(sinkFileName));
@@ -14,7 +14,7 @@ public class JobMasterTest {
Assert.assertNull(jobMaster.getGraphManager());
Assert.assertNull(jobMaster.getResourceManager());
Assert.assertNull(jobMaster.getJobMasterActor());
Assert.assertFalse(jobMaster.init(false));
Assert.assertFalse(jobMaster.init());
}
}
@@ -7,7 +7,7 @@ import io.ray.streaming.runtime.BaseUnitTest;
import io.ray.streaming.runtime.config.StreamingConfig;
import io.ray.streaming.runtime.config.global.CommonConfig;
import io.ray.streaming.runtime.core.resource.Container;
import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext;
import io.ray.streaming.runtime.master.JobRuntimeContext;
import io.ray.streaming.runtime.util.RayUtils;
import java.util.HashMap;
import java.util.List;
@@ -44,8 +44,8 @@ public class ResourceManagerTest extends BaseUnitTest {
Map<String, String> conf = new HashMap<String, String>();
conf.put(CommonConfig.JOB_NAME, "testApi");
StreamingConfig config = new StreamingConfig(conf);
JobMasterRuntimeContext jobMasterRuntimeContext = new JobMasterRuntimeContext(config);
ResourceManager resourceManager = new ResourceManagerImpl(jobMasterRuntimeContext);
JobRuntimeContext jobRuntimeContext = new JobRuntimeContext(config);
ResourceManager resourceManager = new ResourceManagerImpl(jobRuntimeContext);
// test register container
List<Container> containers = resourceManager.getRegisteredContainers();
@@ -9,7 +9,7 @@ import io.ray.streaming.runtime.core.graph.ExecutionGraphTest;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph;
import io.ray.streaming.runtime.core.resource.Container;
import io.ray.streaming.runtime.core.resource.ResourceType;
import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext;
import io.ray.streaming.runtime.master.JobRuntimeContext;
import io.ray.streaming.runtime.master.graphmanager.GraphManager;
import io.ray.streaming.runtime.master.graphmanager.GraphManagerImpl;
import io.ray.streaming.runtime.master.resourcemanager.ResourceAssignmentView;
@@ -64,7 +64,7 @@ public class PipelineFirstStrategyTest extends BaseUnitTest {
Map<String, String> jobConf = new HashMap<>();
StreamingConfig streamingConfig = new StreamingConfig(jobConf);
GraphManager graphManager = new GraphManagerImpl(new JobMasterRuntimeContext(streamingConfig));
GraphManager graphManager = new GraphManagerImpl(new JobRuntimeContext(streamingConfig));
JobGraph jobGraph = ExecutionGraphTest.buildJobGraph();
ExecutionGraph executionGraph = ExecutionGraphTest.buildExecutionGraph(graphManager, jobGraph);
ResourceAssignmentView assignmentView = strategy.assignResource(containers, executionGraph);
@@ -9,7 +9,7 @@ import io.ray.streaming.api.function.impl.FlatMapFunction;
import io.ray.streaming.api.function.impl.ReduceFunction;
import io.ray.streaming.api.stream.DataStreamSource;
import io.ray.streaming.runtime.BaseUnitTest;
import io.ray.streaming.runtime.transfer.channel.ChannelId;
import io.ray.streaming.runtime.transfer.ChannelId;
import io.ray.streaming.runtime.util.EnvUtil;
import io.ray.streaming.util.Config;
import java.io.File;
@@ -6,11 +6,11 @@ import io.ray.api.Ray;
import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
import io.ray.streaming.runtime.config.StreamingWorkerConfig;
import io.ray.streaming.runtime.transfer.ChannelCreationParametersBuilder;
import io.ray.streaming.runtime.transfer.ChannelId;
import io.ray.streaming.runtime.transfer.DataMessage;
import io.ray.streaming.runtime.transfer.DataReader;
import io.ray.streaming.runtime.transfer.DataWriter;
import io.ray.streaming.runtime.transfer.TransferHandler;
import io.ray.streaming.runtime.transfer.channel.ChannelId;
import io.ray.streaming.runtime.transfer.message.DataMessage;
import io.ray.streaming.util.Config;
import java.lang.management.ManagementFactory;
import java.nio.ByteBuffer;
@@ -104,7 +104,7 @@ class ReaderWorker extends Worker {
new JavaFunctionDescriptor(Worker.class.getName(), "onWriterMessage", "([B)V"),
new JavaFunctionDescriptor(Worker.class.getName(), "onWriterMessageSync", "([B)[B"));
StreamingWorkerConfig workerConfig = new StreamingWorkerConfig(conf);
dataReader = new DataReader(inputQueueList, inputActors, new HashMap<>(), workerConfig);
dataReader = new DataReader(inputQueueList, inputActors, workerConfig);
// Should not GetBundle in RayCall thread
Thread readThread = new Thread(Ray.wrapRunnable(new Runnable() {
@@ -124,7 +124,7 @@ class ReaderWorker extends Worker {
int checkPointId = 1;
for (int i = 0; i < msgCount * inputQueueList.size(); ++i) {
DataMessage dataMessage = (DataMessage) dataReader.read(100);
DataMessage dataMessage = dataReader.read(100);
if (dataMessage == null) {
LOGGER.error("dataMessage is null");
@@ -232,7 +232,7 @@ class WriterWorker extends Worker {
new JavaFunctionDescriptor(Worker.class.getName(), "onReaderMessage", "([B)V"),
new JavaFunctionDescriptor(Worker.class.getName(), "onReaderMessageSync", "([B)[B"));
StreamingWorkerConfig workerConfig = new StreamingWorkerConfig(conf);
dataWriter = new DataWriter(outputQueueList, outputActors, new HashMap<>(), workerConfig);
dataWriter = new DataWriter(outputQueueList, outputActors, workerConfig);
Thread writerThread = new Thread(Ray.wrapRunnable(new Runnable() {
@Override
public void run() {
@@ -4,7 +4,6 @@ import static org.testng.Assert.assertEquals;
import io.ray.streaming.runtime.BaseUnitTest;
import io.ray.streaming.runtime.transfer.channel.ChannelId;
import io.ray.streaming.runtime.util.EnvUtil;
import org.testng.annotations.Test;
@@ -3,4 +3,4 @@ log4j.rootLogger=INFO, stdout
log4j.appender.stdout=org.apache.log4j.ConsoleAppender
log4j.appender.stdout.Target=System.out
log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SS} %-4p %c{1}:%L [%t] - %m%n
log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n