getAllActorsId() {
+ return getAllActors().stream()
+ .map(BaseActorHandle::getId)
+ .collect(Collectors.toList());
+ }
+
}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobEdge.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobEdge.java
index 6ab2fd911..6aa7936b2 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobEdge.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobEdge.java
@@ -3,11 +3,12 @@ package io.ray.streaming.runtime.core.graph.executiongraph;
import com.google.common.base.MoreObjects;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.jobgraph.JobEdge;
+import java.io.Serializable;
/**
* An edge that connects two execution job vertices.
*/
-public class ExecutionJobEdge {
+public class ExecutionJobEdge implements Serializable {
/**
* The source(upstream) execution job vertex.
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobVertex.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobVertex.java
index f0c87bd0f..b617cc053 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobVertex.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobVertex.java
@@ -8,6 +8,7 @@ import io.ray.streaming.jobgraph.JobVertex;
import io.ray.streaming.jobgraph.VertexType;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.runtime.config.master.ResourceConfig;
+import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@@ -20,7 +21,7 @@ import org.aeonbits.owner.ConfigFactory;
* Execution job vertex is the physical form of {@link JobVertex} and
* every execution job vertex is corresponding to a group of {@link ExecutionVertex}.
*/
-public class ExecutionJobVertex {
+public class ExecutionJobVertex implements Serializable {
/**
* Unique id. Use {@link JobVertex}'s id directly.
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertex.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertex.java
index 0135b35ed..5d6a2556c 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertex.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertex.java
@@ -9,6 +9,7 @@ import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.runtime.config.master.ResourceConfig;
import io.ray.streaming.runtime.core.resource.ContainerId;
import io.ray.streaming.runtime.core.resource.ResourceType;
+import io.ray.streaming.runtime.transfer.channel.ChannelId;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
@@ -60,6 +61,8 @@ public class ExecutionVertex implements Serializable {
*/
private ContainerId containerId;
+ private String pid;
+
/**
* Worker actor handle.
*/
@@ -73,6 +76,14 @@ public class ExecutionVertex implements Serializable {
private List inputEdges = new ArrayList<>();
private List outputEdges = new ArrayList<>();
+ private transient List outputChannelIdList;
+ private transient List inputChannelIdList;
+
+ private transient List outputActorList;
+ private transient List inputActorList;
+ private Map exeVertexChannelMap;
+
+
public ExecutionVertex(
int globalIndex,
int index,
@@ -92,9 +103,7 @@ public class ExecutionVertex implements Serializable {
}
private Map genWorkerConfig(Map jobConfig) {
- Map workerConfig = new HashMap<>();
- workerConfig.putAll(jobConfig);
- return workerConfig;
+ return new HashMap<>(jobConfig);
}
public int getExecutionVertexId() {
@@ -161,14 +170,14 @@ public class ExecutionVertex implements Serializable {
return workerActor;
}
- public ActorId getWorkerActorId() {
- return workerActor.getId();
- }
-
public void setWorkerActor(BaseActorHandle workerActor) {
this.workerActor = workerActor;
}
+ public ActorId getWorkerActorId() {
+ return workerActor.getId();
+ }
+
public List getInputEdges() {
return inputEdges;
}
@@ -199,6 +208,14 @@ public class ExecutionVertex implements Serializable {
.collect(Collectors.toList());
}
+  /**
+   * Returns the id of the worker actor bound to this vertex.
+   *
+   * @return the worker actor's id, or {@code null} when no worker actor is set
+   */
+  public ActorId getActorId() {
+    if (workerActor == null) {
+      return null;
+    }
+    return workerActor.getId();
+  }
+
+ /**
+ * Returns the actor name of this vertex, which is the string form of
+ * {@code executionVertexId}.
+ *
+ * @return {@code executionVertexId} rendered via {@link String#valueOf}
+ */
+ public String getActorName() {
+ return String.valueOf(executionVertexId);
+ }
+
public Map getResource() {
return resource;
}
@@ -219,12 +236,89 @@ public class ExecutionVertex implements Serializable {
this.containerId = containerId;
}
+ /**
+ * Returns the pid string stored on this vertex.
+ *
+ * @return the stored pid (presumably {@code null} until {@link #setPid(String)} is called;
+ *     confirm against the constructor)
+ */
+ public String getPid() {
+ return pid;
+ }
+
+ /**
+ * Stores the pid string for this vertex, replacing any previous value.
+ *
+ * @param pid the pid to remember; no validation is applied
+ */
+ public void setPid(String pid) {
+ this.pid = pid;
+ }
+
public void setContainerIfNotExist(ContainerId containerId) {
if (null == this.containerId) {
this.containerId = containerId;
}
}
+ /*---------channel-actor relations---------*/
+  /**
+   * Returns the ids of the channels leaving this vertex, one per output edge.
+   *
+   * <p>The list is built on first access by {@code generateActorChannelInfo()} and then reused.
+   *
+   * @return the cached output channel id list
+   */
+  public List getOutputChannelIdList() {
+    if (outputChannelIdList != null) {
+      return outputChannelIdList;
+    }
+    generateActorChannelInfo();
+    return outputChannelIdList;
+  }
+
+  /**
+   * Returns the worker actors of the vertices targeted by this vertex's output edges.
+   *
+   * <p>The list is built on first access by {@code generateActorChannelInfo()} and then reused.
+   *
+   * @return the cached downstream actor list
+   */
+  public List getOutputActorList() {
+    if (outputActorList != null) {
+      return outputActorList;
+    }
+    generateActorChannelInfo();
+    return outputActorList;
+  }
+
+  /**
+   * Returns the ids of the channels entering this vertex, one per input edge.
+   *
+   * <p>The list is built on first access by {@code generateActorChannelInfo()} and then reused.
+   *
+   * @return the cached input channel id list
+   */
+  public List getInputChannelIdList() {
+    if (inputChannelIdList != null) {
+      return inputChannelIdList;
+    }
+    generateActorChannelInfo();
+    return inputChannelIdList;
+  }
+
+  /**
+   * Returns the worker actors of the vertices at the source of this vertex's input edges.
+   *
+   * <p>The list is built on first access by {@code generateActorChannelInfo()} and then reused.
+   *
+   * @return the cached upstream actor list
+   */
+  public List getInputActorList() {
+    if (inputActorList != null) {
+      return inputActorList;
+    }
+    generateActorChannelInfo();
+    return inputActorList;
+  }
+
+
+  /**
+   * Finds the id of the channel that connects this vertex with the given peer.
+   *
+   * <p>The vertex-to-channel map is generated on first use.
+   *
+   * @param peerVertex the vertex on the other end of the channel
+   * @return the channel id registered for the peer's execution vertex id, or {@code null} when
+   *     the map holds no entry for it
+   */
+  public String getChannelIdByPeerVertex(ExecutionVertex peerVertex) {
+    if (null == exeVertexChannelMap) {
+      generateActorChannelInfo();
+    }
+    final int peerVertexId = peerVertex.getExecutionVertexId();
+    return exeVertexChannelMap.get(peerVertexId);
+  }
+
+
+  /**
+   * Rebuilds every cached channel/actor view of this vertex from its current edges.
+   *
+   * <p>All five caches (input/output channel id lists, input/output actor lists and the
+   * peer-vertex-to-channel map) are replaced together, so they always describe the same
+   * snapshot of the edges and of the peers' worker actors.
+   */
+  private void generateActorChannelInfo() {
+    inputChannelIdList = new ArrayList<>();
+    inputActorList = new ArrayList<>();
+    outputChannelIdList = new ArrayList<>();
+    outputActorList = new ArrayList<>();
+    exeVertexChannelMap = new HashMap<>();
+
+    // Upstream side: the channel id is generated as (peer vertex id, this vertex id).
+    // The local is typed (and not named like the field) so the for-each is well-typed.
+    List<ExecutionEdge> upstreamEdges = getInputEdges();
+    for (ExecutionEdge edge : upstreamEdges) {
+      ExecutionVertex source = edge.getSourceExecutionVertex();
+      String channelId = ChannelId.genIdStr(
+          source.getExecutionVertexId(),
+          getExecutionVertexId(),
+          getBuildTime());
+      inputChannelIdList.add(channelId);
+      inputActorList.add(source.getWorkerActor());
+      exeVertexChannelMap.put(source.getExecutionVertexId(), channelId);
+    }
+
+    // Downstream side: the channel id is generated as (this vertex id, peer vertex id).
+    List<ExecutionEdge> downstreamEdges = getOutputEdges();
+    for (ExecutionEdge edge : downstreamEdges) {
+      ExecutionVertex target = edge.getTargetExecutionVertex();
+      String channelId = ChannelId.genIdStr(
+          getExecutionVertexId(),
+          target.getExecutionVertexId(),
+          getBuildTime());
+      outputChannelIdList.add(channelId);
+      outputActorList.add(target.getWorkerActor());
+      exeVertexChannelMap.put(target.getExecutionVertexId(), channelId);
+    }
+  }
+
+
private Map generateResources(ResourceConfig resourceConfig) {
Map resourceMap = new HashMap<>();
if (resourceConfig.isTaskCpuResourceLimit()) {
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java
index 500721d3d..d189c42c1 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java
@@ -15,7 +15,7 @@ public class ProcessBuilder {
public static StreamProcessor buildProcessor(StreamOperator streamOperator) {
OperatorType type = streamOperator.getOpType();
LOGGER.info("Building StreamProcessor, operator type = {}, operator = {}.", type,
- streamOperator.getClass().getSimpleName().toString());
+ streamOperator.getClass().getSimpleName());
switch (type) {
case SOURCE:
return new SourceProcessor<>((SourceOperator) streamOperator);
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/Processor.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/Processor.java
index 3b128376c..54fe76cd8 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/Processor.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/Processor.java
@@ -2,6 +2,7 @@ package io.ray.streaming.runtime.core.processor;
import io.ray.streaming.api.collector.Collector;
import io.ray.streaming.api.context.RuntimeContext;
+import io.ray.streaming.api.function.Function;
import java.io.Serializable;
import java.util.List;
@@ -11,5 +12,15 @@ public interface Processor extends Serializable {
void process(T t);
+ /**
+ * Produces the serializable checkpoint object of this processor.
+ * See {@link Function#saveCheckpoint()}.
+ *
+ * @return the checkpoint object; its content is defined by the implementation
+ */
+ Serializable saveCheckpoint();
+
+ /**
+ * Restores this processor from a checkpoint object.
+ * See {@link Function#loadCheckpoint(Serializable)}.
+ *
+ * @param checkpointObject the object to restore from (presumably one produced by
+ *     {@code saveCheckpoint()}; confirm against callers)
+ */
+ void loadCheckpoint(Serializable checkpointObject);
+
void close();
}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java
index 020f39d16..1cc721a2a 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java
@@ -19,8 +19,8 @@ public class SourceProcessor extends StreamProcessor implements Processo
LOGGER.info("opened {}", this);
}
+ /**
+ * Delegates to the wrapped operator and returns whatever checkpoint object it produces.
+ */
+ @Override
+ public Serializable saveCheckpoint() {
+ return operator.saveCheckpoint();
+ }
+
+ /**
+ * Hands the given checkpoint object to the wrapped operator unchanged.
+ */
+ @Override
+ public void loadCheckpoint(Serializable checkpointObject) {
+ operator.loadCheckpoint(checkpointObject);
+ }
+
@Override
public String toString() {
return this.getClass().getSimpleName();
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobMaster.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobMaster.java
index 60d3a0843..6115e4d50 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobMaster.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobMaster.java
@@ -1,18 +1,36 @@
package io.ray.streaming.runtime.master;
import com.google.common.base.Preconditions;
+import com.google.protobuf.InvalidProtocolBufferException;
import io.ray.api.ActorHandle;
+import io.ray.api.BaseActorHandle;
+import io.ray.api.Ray;
+import io.ray.api.id.ActorId;
import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.runtime.config.StreamingConfig;
import io.ray.streaming.runtime.config.StreamingMasterConfig;
+import io.ray.streaming.runtime.context.ContextBackend;
+import io.ray.streaming.runtime.context.ContextBackendFactory;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph;
+import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
+import io.ray.streaming.runtime.core.resource.Container;
+import io.ray.streaming.runtime.generated.RemoteCall;
+import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext;
+import io.ray.streaming.runtime.master.coordinator.CheckpointCoordinator;
+import io.ray.streaming.runtime.master.coordinator.FailoverCoordinator;
+import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport;
+import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest;
import io.ray.streaming.runtime.master.graphmanager.GraphManager;
import io.ray.streaming.runtime.master.graphmanager.GraphManagerImpl;
import io.ray.streaming.runtime.master.resourcemanager.ResourceManager;
import io.ray.streaming.runtime.master.resourcemanager.ResourceManagerImpl;
import io.ray.streaming.runtime.master.scheduler.JobSchedulerImpl;
+import io.ray.streaming.runtime.util.CheckpointStateUtil;
+import io.ray.streaming.runtime.util.ResourceUtil;
+import io.ray.streaming.runtime.util.Serializer;
import io.ray.streaming.runtime.worker.JobWorker;
import java.util.Map;
+import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -24,33 +42,68 @@ public class JobMaster {
private static final Logger LOG = LoggerFactory.getLogger(JobMaster.class);
- private JobRuntimeContext runtimeContext;
+ private JobMasterRuntimeContext runtimeContext;
private ResourceManager resourceManager;
private JobSchedulerImpl scheduler;
private GraphManager graphManager;
private StreamingMasterConfig conf;
+ private ContextBackend contextBackend;
+
private ActorHandle jobMasterActor;
+ // coordinators
+ private CheckpointCoordinator checkpointCoordinator;
+ private FailoverCoordinator failoverCoordinator;
+
public JobMaster(Map confMap) {
LOG.info("Creating job master with conf: {}.", confMap);
StreamingConfig streamingConfig = new StreamingConfig(confMap);
this.conf = streamingConfig.masterConfig;
+ this.contextBackend = ContextBackendFactory.getContextBackend(this.conf);
// init runtime context
- runtimeContext = new JobRuntimeContext(streamingConfig);
+ runtimeContext = new JobMasterRuntimeContext(streamingConfig);
+
+ // load checkpoint if is recover
+ if (Ray.getRuntimeContext().wasCurrentActorRestarted()) {
+ loadMasterCheckpoint();
+ }
LOG.info("Finished creating job master.");
}
+  /**
+   * Builds the backend key under which the job master's runtime context is stored.
+   *
+   * @param conf master configuration supplying the key prefix and the job name
+   * @return the configured checkpoint key prefix immediately followed by the job name
+   */
+  public static String getJobMasterRuntimeContextKey(StreamingMasterConfig conf) {
+    final String jobName = conf.commonConfig.jobName();
+    return conf.checkpointConfig.jobMasterContextCpPrefixKey() + jobName;
+  }
+
+  /**
+   * Restores the runtime context from the context backend after a restart.
+   *
+   * <p>When the backend holds nothing for this job, the freshly created context is kept and
+   * checkpoint id 0 is recorded. Otherwise the stored context replaces the current one and
+   * {@code init(true)} is run.
+   */
+  private void loadMasterCheckpoint() {
+    LOG.info("Start to load JobMaster's checkpoint.");
+    // recover runtime context
+    final byte[] savedContext =
+        CheckpointStateUtil.get(contextBackend, getJobMasterRuntimeContextKey(getConf()));
+
+    if (savedContext != null) {
+      this.runtimeContext = Serializer.decode(savedContext);
+
+      // FO case, triggered by ray, we need to register context when loading checkpoint
+      LOG.info("JobMaster recover runtime context[{}] from state backend.", runtimeContext);
+      init(true);
+    } else {
+      LOG.warn("JobMaster got empty checkpoint from state backend. Skip loading checkpoint.");
+      // cp 0 was automatically saved when job started, see StreamTask.
+      runtimeContext.checkpointIds.add(0L);
+    }
+  }
+
/**
* Init JobMaster. To initiate or recover other components(like metrics and extra coordinators).
*
* @return init result
*/
- public Boolean init() {
- LOG.info("Initializing job master.");
+ public Boolean init(boolean isRecover) {
+ LOG.info("Initializing job master, isRecover={}.", isRecover);
if (this.runtimeContext.getExecutionGraph() == null) {
LOG.error("Init job master failed. Job graphs is null.");
@@ -60,6 +113,14 @@ public class JobMaster {
ExecutionGraph executionGraph = graphManager.getExecutionGraph();
Preconditions.checkArgument(executionGraph != null, "no execution graph");
+ // init coordinators
+ checkpointCoordinator = new CheckpointCoordinator(this);
+ checkpointCoordinator.start();
+ failoverCoordinator = new FailoverCoordinator(this, isRecover);
+ failoverCoordinator.start();
+
+ saveContext();
+
LOG.info("Finished initializing job master.");
return true;
}
@@ -101,11 +162,86 @@ public class JobMaster {
return true;
}
+  /**
+   * Serializes the runtime context and writes it to the context backend under this job's key.
+   *
+   * <p>Does nothing when either the runtime context or the master config is missing.
+   */
+  public synchronized void saveContext() {
+    if (runtimeContext == null || getConf() == null) {
+      return;
+    }
+    LOG.debug("Save JobMaster context.");
+
+    final byte[] encodedContext = Serializer.encode(runtimeContext);
+    final String contextKey = getJobMasterRuntimeContextKey(getConf());
+    CheckpointStateUtil.put(contextBackend, contextKey, encodedContext);
+  }
+
+  /**
+   * Handles a worker's checkpoint commit report and forwards it to the checkpoint coordinator.
+   *
+   * @param reportBytes a serialized {@code RemoteCall.BaseWorkerCmd} whose detail is a
+   *     {@code RemoteCall.WorkerCommitReport}
+   * @return a serialized {@code RemoteCall.BoolResult}: the coordinator's answer, or
+   *     {@code false} when the report could not be parsed
+   */
+  public byte[] reportJobWorkerCommit(byte[] reportBytes) {
+    boolean accepted = false;
+    try {
+      final RemoteCall.BaseWorkerCmd cmdPb = RemoteCall.BaseWorkerCmd.parseFrom(reportBytes);
+      final ActorId workerId = ActorId.fromBytes(cmdPb.getActorId().toByteArray());
+      final long remoteCallCost = System.currentTimeMillis() - cmdPb.getTimestamp();
+      LOG.info("Vertex {}, request job worker commit cost {}ms, actorId={}.",
+          getExecutionVertex(workerId), remoteCallCost, workerId);
+
+      final RemoteCall.WorkerCommitReport commitPb =
+          cmdPb.getDetail().unpack(RemoteCall.WorkerCommitReport.class);
+      accepted = checkpointCoordinator.reportJobWorkerCommit(
+          new WorkerCommitReport(workerId, commitPb.getCommitCheckpointId()));
+    } catch (InvalidProtocolBufferException e) {
+      LOG.error("Parse job worker commit has exception.", e);
+    }
+    return RemoteCall.BoolResult.newBuilder().setBoolRes(accepted).build().toByteArray();
+  }
+
+  /**
+   * Handles a worker's rollback request and forwards it to the failover coordinator.
+   *
+   * <p>Requests from actors unknown to the execution graph are rejected. For known actors the
+   * reported worker pid is stored on the execution vertex, and the hostname of the vertex's
+   * container (empty when the container is not registered) is attached to the request.
+   *
+   * @param requestBytes a serialized {@code RemoteCall.BaseWorkerCmd} whose detail is a
+   *     {@code RemoteCall.WorkerRollbackRequest}
+   * @return a serialized {@code RemoteCall.BoolResult}: the coordinator's answer, or
+   *     {@code false} when the actor is unknown or any failure occurs while handling
+   */
+  public byte[] requestJobWorkerRollback(byte[] requestBytes) {
+    boolean ret = false;
+    try {
+      RemoteCall.BaseWorkerCmd requestPb = RemoteCall.BaseWorkerCmd.parseFrom(requestBytes);
+      ActorId actorId = ActorId.fromBytes(requestPb.getActorId().toByteArray());
+      long remoteCallCost = System.currentTimeMillis() - requestPb.getTimestamp();
+      ExecutionGraph executionGraph = graphManager.getExecutionGraph();
+      if (!executionGraph.getActorById(actorId).isPresent()) {
+        LOG.warn("Skip this invalid rollback, actor id {} is not found.", actorId);
+        return RemoteCall.BoolResult.newBuilder().setBoolRes(false).build().toByteArray();
+      }
+      ExecutionVertex exeVertex = getExecutionVertex(actorId);
+      LOG.info("Vertex {}, request job worker rollback cost {}ms, actorId={}.",
+          exeVertex, remoteCallCost, actorId);
+      RemoteCall.WorkerRollbackRequest rollbackPb
+          = RemoteCall.WorkerRollbackRequest.parseFrom(requestPb.getDetail().getValue());
+      exeVertex.setPid(rollbackPb.getWorkerPid());
+
+      // To find old container where slot is located in. The Optional is typed so that
+      // getHostname() resolves on Container rather than on Object.
+      Optional<Container> container = ResourceUtil.getContainerById(
+          resourceManager.getRegisteredContainers(),
+          exeVertex.getContainerId()
+      );
+      String hostname = container.isPresent() ? container.get().getHostname() : "";
+
+      WorkerRollbackRequest request = new WorkerRollbackRequest(
+          actorId, rollbackPb.getExceptionMsg(), hostname, exeVertex.getPid()
+      );
+
+      ret = failoverCoordinator.requestJobWorkerRollback(request);
+      LOG.info("Vertex {} request rollback, exception msg : {}.",
+          exeVertex, rollbackPb.getExceptionMsg());
+
+    } catch (Throwable e) {
+      // Deliberately broad: any failure is reported to the worker as a "false" result.
+      LOG.error("Parse job worker rollback has exception.", e);
+    }
+    return RemoteCall.BoolResult.newBuilder().setBoolRes(ret).build().toByteArray();
+  }
+
+ /**
+ * Looks up the execution vertex bound to the given actor in the current execution graph.
+ *
+ * @param id actor id to resolve
+ * @return whatever {@code getExecutionVertexByActorId} yields for {@code id}
+ */
+ private ExecutionVertex getExecutionVertex(ActorId id) {
+ return graphManager.getExecutionGraph().getExecutionVertexByActorId(id);
+ }
+
public ActorHandle getJobMasterActor() {
return jobMasterActor;
}
- public JobRuntimeContext getRuntimeContext() {
+ public JobMasterRuntimeContext getRuntimeContext() {
return runtimeContext;
}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/context/JobMasterRuntimeContext.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/context/JobMasterRuntimeContext.java
new file mode 100644
index 000000000..c9e6e8f57
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/context/JobMasterRuntimeContext.java
@@ -0,0 +1,81 @@
+package io.ray.streaming.runtime.master.context;
+
+import com.google.common.base.MoreObjects;
+import com.google.common.collect.Sets;
+import io.ray.streaming.jobgraph.JobGraph;
+import io.ray.streaming.runtime.config.StreamingConfig;
+import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph;
+import io.ray.streaming.runtime.master.coordinator.command.BaseWorkerCmd;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+
+/**
+ * Runtime context for job master, which will be stored in backend when saving checkpoint.
+ *
+ * <p>Holds the streaming config, the job graph, the execution graph, checkpoint bookkeeping
+ * and the command queues/sets used for checkpoint and failover handling.
+ */
+public class JobMasterRuntimeContext implements Serializable {
+
+  /*--------------Checkpoint----------------*/
+  /** Ids of finished checkpoints in insertion order; the last element is the latest one. */
+  public volatile List<Long> checkpointIds = new ArrayList<>();
+  /** Id of the most recently triggered checkpoint. */
+  public volatile long lastCheckpointId = 0;
+  /** Wall-clock millis used as the reference point for checkpoint interval and timeout. */
+  public volatile long lastCpTimestamp = 0;
+  /** Commands waiting for the checkpoint coordinator. */
+  public volatile BlockingQueue<BaseWorkerCmd> cpCmds = new LinkedBlockingQueue<>();
+  /*--------------Failover----------------*/
+  /** Bounded queue of failover commands (element type presumed BaseWorkerCmd; confirm). */
+  public volatile BlockingQueue<BaseWorkerCmd> foCmds = new ArrayBlockingQueue<>(8192);
+  /** Failover commands not yet finished (element type presumed BaseWorkerCmd; confirm). */
+  public volatile Set<BaseWorkerCmd> unfinishedFoCmds = Sets.newConcurrentHashSet();
+  private StreamingConfig conf;
+  private JobGraph jobGraph;
+  private volatile ExecutionGraph executionGraph;
+
+  /**
+   * Creates a context with empty checkpoint/failover state.
+   *
+   * @param conf streaming configuration of the job
+   */
+  public JobMasterRuntimeContext(StreamingConfig conf) {
+    this.conf = conf;
+  }
+
+  /** Returns the job name read from the master's common config. */
+  public String getJobName() {
+    return conf.masterConfig.commonConfig.jobName();
+  }
+
+  public StreamingConfig getConf() {
+    return conf;
+  }
+
+  public JobGraph getJobGraph() {
+    return jobGraph;
+  }
+
+  public void setJobGraph(JobGraph jobGraph) {
+    this.jobGraph = jobGraph;
+  }
+
+  public ExecutionGraph getExecutionGraph() {
+    return executionGraph;
+  }
+
+  public void setExecutionGraph(ExecutionGraph executionGraph) {
+    this.executionGraph = executionGraph;
+  }
+
+  /**
+   * Returns the id of the latest finished checkpoint.
+   *
+   * @return the last element of {@code checkpointIds}, or {@code 0L} when none is recorded
+   */
+  public Long getLastValidCheckpointId() {
+    if (checkpointIds.isEmpty()) {
+      // 0L is invalid checkpoint id, worker will pass it
+      return 0L;
+    }
+    return checkpointIds.get(checkpointIds.size() - 1);
+  }
+
+  @Override
+  public String toString() {
+    return MoreObjects.toStringHelper(this)
+        .add("jobGraph", jobGraph)
+        .add("executionGraph", executionGraph)
+        .add("conf", conf.getMap())
+        .toString();
+  }
+
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/BaseCoordinator.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/BaseCoordinator.java
new file mode 100644
index 000000000..ece4de4b7
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/BaseCoordinator.java
@@ -0,0 +1,44 @@
+package io.ray.streaming.runtime.master.coordinator;
+
+import io.ray.api.Ray;
+import io.ray.streaming.runtime.master.JobMaster;
+import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext;
+import io.ray.streaming.runtime.master.graphmanager.GraphManager;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Base class for job master coordinators.
+ *
+ * <p>A coordinator is a {@link Runnable} executed on its own thread. {@link #start()} launches
+ * that thread; {@link #stop()} raises the {@code closed} flag (which subclasses are expected
+ * to poll in {@code run()}) and waits a bounded time for the thread to finish.
+ */
+public abstract class BaseCoordinator implements Runnable {
+
+  private static final Logger LOG = LoggerFactory.getLogger(BaseCoordinator.class);
+
+  protected final JobMaster jobMaster;
+
+  protected final JobMasterRuntimeContext runtimeContext;
+  protected final GraphManager graphManager;
+  /** Set by {@link #stop()}; volatile so the coordinator thread sees the change. */
+  protected volatile boolean closed;
+  private Thread thread;
+
+  /**
+   * Captures the job master together with its runtime context and graph manager.
+   *
+   * @param jobMaster the owning job master
+   */
+  public BaseCoordinator(JobMaster jobMaster) {
+    this.jobMaster = jobMaster;
+    this.runtimeContext = jobMaster.getRuntimeContext();
+    this.graphManager = jobMaster.getGraphManager();
+  }
+
+  /** Starts this coordinator on a new thread named after the concrete class and start time. */
+  public void start() {
+    thread = new Thread(Ray.wrapRunnable(this),
+        this.getClass().getName() + "-" + System.currentTimeMillis());
+    thread.start();
+  }
+
+  /**
+   * Requests termination and waits up to 30 seconds for the coordinator thread to exit.
+   *
+   * <p>If the waiting thread is interrupted, the interruption is logged and the interrupt
+   * flag is restored so that callers further up the stack can still observe it.
+   */
+  public void stop() {
+    closed = true;
+
+    if (thread == null) {
+      return;
+    }
+    try {
+      thread.join(30000);
+    } catch (InterruptedException e) {
+      LOG.error("Coordinator thread exit has exception.", e);
+      // join() cleared the flag when it threw; put it back instead of swallowing it.
+      Thread.currentThread().interrupt();
+    }
+  }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/CheckpointCoordinator.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/CheckpointCoordinator.java
new file mode 100644
index 000000000..862528776
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/CheckpointCoordinator.java
@@ -0,0 +1,215 @@
+package io.ray.streaming.runtime.master.coordinator;
+
+import com.google.common.base.Preconditions;
+import io.ray.api.BaseActorHandle;
+import io.ray.api.ObjectRef;
+import io.ray.api.id.ActorId;
+import io.ray.runtime.exception.RayException;
+import io.ray.streaming.runtime.master.JobMaster;
+import io.ray.streaming.runtime.master.coordinator.command.BaseWorkerCmd;
+import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport;
+import io.ray.streaming.runtime.rpc.RemoteCallWorker;
+import io.ray.streaming.runtime.worker.JobWorker;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * CheckpointCoordinator is the controller of checkpoints: it triggers checkpoints, collects
+ * {@link JobWorker} commit reports, and asks every {@link JobWorker} to clear expired
+ * checkpoints once a new checkpoint has finished.
+ *
+ * <p>Commands reach the coordinator loop through the runtime context's {@code cpCmds} queue.
+ */
+public class CheckpointCoordinator extends BaseCoordinator {
+
+  private static final Logger LOG = LoggerFactory.getLogger(CheckpointCoordinator.class);
+
+  /** Milliseconds per second, as a long so that seconds-to-millis conversion cannot overflow. */
+  private static final long MILLIS_PER_SECOND = 1000L;
+
+  /** Actors that have not yet committed the checkpoint currently in flight. */
+  private final Set<ActorId> pendingCheckpointActors = new HashSet<>();
+  /** Checkpoint ids interrupted since the last trigger; used to skip duplicate interrupts. */
+  private final Set<Long> interruptedCheckpointSet = new HashSet<>();
+  private final int cpIntervalSecs;
+  private final int cpTimeoutSecs;
+
+  /**
+   * Reads the checkpoint interval and timeout from the master config and resets the last
+   * checkpoint timestamp to now.
+   *
+   * @param jobMaster the owning job master
+   */
+  public CheckpointCoordinator(JobMaster jobMaster) {
+    super(jobMaster);
+
+    // get checkpoint interval from conf
+    this.cpIntervalSecs = runtimeContext.getConf().masterConfig.checkpointConfig.cpIntervalSecs();
+    this.cpTimeoutSecs = runtimeContext.getConf().masterConfig.checkpointConfig.cpTimeoutSecs();
+
+    // Trigger next checkpoint in interval by reset last checkpoint timestamp.
+    runtimeContext.lastCpTimestamp = System.currentTimeMillis();
+  }
+
+  /**
+   * Coordinator loop: handles at most one queued command per iteration, then either checks
+   * the in-flight checkpoint for timeout or, when nothing is pending, triggers the next one
+   * if the interval has elapsed. Runs until {@code closed} is set.
+   */
+  @Override
+  public void run() {
+    while (!closed) {
+      try {
+        final BaseWorkerCmd command = runtimeContext.cpCmds.poll(1, TimeUnit.SECONDS);
+        if (command != null) {
+          if (command instanceof WorkerCommitReport) {
+            processCommitReport((WorkerCommitReport) command);
+          } else {
+            // Any other command type aborts the checkpoint currently in flight.
+            interruptCheckpoint();
+          }
+        }
+
+        if (!pendingCheckpointActors.isEmpty()) {
+          // if wait commit report timeout, this cp fail, and restart next cp
+          if (timeoutOnWaitCheckpoint()) {
+            LOG.warn("Waiting for checkpoint {} timeout, pending cp actors is {}.",
+                runtimeContext.lastCheckpointId,
+                graphManager.getExecutionGraph().getActorName(pendingCheckpointActors));
+
+            interruptCheckpoint();
+          }
+        } else {
+          maybeTriggerCheckpoint();
+        }
+      } catch (Throwable e) {
+        LOG.error("Checkpoint coordinator occur err.", e);
+        try {
+          interruptCheckpoint();
+        } catch (Throwable interruptE) {
+          // Still ignored for control flow, but keep the cause in the log.
+          LOG.error("Ignore interrupt checkpoint exception in catch block.", interruptE);
+        }
+      }
+    }
+    LOG.warn("Checkpoint coordinator thread exit.");
+  }
+
+  /**
+   * Queues a worker's commit report for the coordinator loop.
+   *
+   * @param report the commit report to enqueue
+   * @return {@code true} if the report was queued, {@code false} if the queue rejected it
+   */
+  public Boolean reportJobWorkerCommit(WorkerCommitReport report) {
+    LOG.info("Report job worker commit {}.", report);
+
+    boolean accepted = runtimeContext.cpCmds.offer(report);
+    if (!accepted) {
+      LOG.warn("Report job worker commit failed, because command queue is full.");
+    }
+    return accepted;
+  }
+
+  /**
+   * Applies one commit report: removes the reporting actor from the pending set and, when it
+   * was the last one, records the checkpoint as finished, clears expired state and saves the
+   * master context. Failures are logged and not propagated.
+   */
+  private void processCommitReport(WorkerCommitReport commitReport) {
+    LOG.info("Start process commit report {}, from actor name={}.", commitReport,
+        graphManager.getExecutionGraph().getActorName(commitReport.fromActorId));
+
+    try {
+      Preconditions.checkArgument(
+          commitReport.commitCheckpointId == runtimeContext.lastCheckpointId,
+          "expect checkpointId %s, but got %s",
+          runtimeContext.lastCheckpointId, commitReport);
+
+      if (!pendingCheckpointActors.contains(commitReport.fromActorId)) {
+        LOG.warn("Invalid commit report, skipped.");
+        return;
+      }
+
+      pendingCheckpointActors.remove(commitReport.fromActorId);
+      LOG.info("Pending actors after this commit: {}.",
+          graphManager.getExecutionGraph().getActorName(pendingCheckpointActors));
+
+      // checkpoint finish
+      if (pendingCheckpointActors.isEmpty()) {
+        // actor finish
+        runtimeContext.checkpointIds.add(runtimeContext.lastCheckpointId);
+
+        if (clearExpiredCpStateAndQueueMsg()) {
+          // save master context
+          jobMaster.saveContext();
+
+          LOG.info("Finish checkpoint: {}.", runtimeContext.lastCheckpointId);
+        } else {
+          LOG.warn("Fail to do checkpoint: {}.", runtimeContext.lastCheckpointId);
+        }
+      }
+
+      LOG.info("Process commit report {} success.", commitReport);
+    } catch (Throwable e) {
+      LOG.warn("Process commit report has exception.", e);
+    }
+  }
+
+  /**
+   * Starts the next checkpoint: marks every actor as pending, increments the checkpoint id
+   * and asks each source actor to trigger it.
+   *
+   * @throws RayException if a source actor's trigger call yields a {@link RayException}
+   */
+  private void triggerCheckpoint() {
+    interruptedCheckpointSet.clear();
+    if (LOG.isInfoEnabled()) {
+      LOG.info("Start trigger checkpoint {}.", runtimeContext.lastCheckpointId + 1);
+    }
+
+    // do the checkpoint
+    pendingCheckpointActors.addAll(graphManager.getExecutionGraph().getAllActorsId());
+
+    // inc last checkpoint id
+    ++runtimeContext.lastCheckpointId;
+
+    final List<ObjectRef<?>> sourcesRet = new ArrayList<>();
+
+    graphManager.getExecutionGraph().getSourceActors().forEach(actor -> {
+      sourcesRet.add(RemoteCallWorker.triggerCheckpoint(
+          actor, runtimeContext.lastCheckpointId));
+    });
+
+    for (ObjectRef<?> callResult : sourcesRet) {
+      // Resolve the result once instead of calling get() for every use.
+      Object result = callResult.get();
+      if (result instanceof RayException) {
+        RayException failure = (RayException) result;
+        LOG.warn("Trigger checkpoint has exception.", failure);
+        throw failure;
+      }
+    }
+    runtimeContext.lastCpTimestamp = System.currentTimeMillis();
+    LOG.info("Trigger checkpoint success.");
+  }
+
+  /**
+   * Aborts the checkpoint in flight (at most once per checkpoint id), notifies all actors
+   * when that checkpoint is newer than the last valid one, clears the pending set and then
+   * triggers the next checkpoint if the interval has elapsed.
+   */
+  private void interruptCheckpoint() {
+    // notify checkpoint timeout is time-consuming while many workers crash or
+    // container failover.
+    if (interruptedCheckpointSet.contains(runtimeContext.lastCheckpointId)) {
+      LOG.warn("Skip interrupt duplicated checkpoint id : {}.", runtimeContext.lastCheckpointId);
+      return;
+    }
+    interruptedCheckpointSet.add(runtimeContext.lastCheckpointId);
+    LOG.warn("Interrupt checkpoint, checkpoint id : {}.", runtimeContext.lastCheckpointId);
+
+    List<BaseActorHandle> allActor = graphManager.getExecutionGraph().getAllActors();
+    if (runtimeContext.lastCheckpointId > runtimeContext.getLastValidCheckpointId()) {
+      RemoteCallWorker
+          .notifyCheckpointTimeoutParallel(allActor, runtimeContext.lastCheckpointId);
+    }
+
+    pendingCheckpointActors.clear();
+    maybeTriggerCheckpoint();
+  }
+
+  /** Triggers a checkpoint only when the configured interval has elapsed. */
+  private void maybeTriggerCheckpoint() {
+    if (readyToTrigger()) {
+      triggerCheckpoint();
+    }
+  }
+
+  /**
+   * Asks all actors to drop checkpoint data that is no longer needed.
+   *
+   * <p>With exactly one recorded checkpoint only the clear call for that id is issued; with
+   * more than one, the oldest id is removed from the record and passed as the expired state
+   * id.
+   *
+   * @return always {@code true}
+   */
+  private boolean clearExpiredCpStateAndQueueMsg() {
+    // queue msg must clear when first checkpoint finish
+    List<BaseActorHandle> allActor = graphManager.getExecutionGraph().getAllActors();
+    if (1 == runtimeContext.checkpointIds.size()) {
+      Long msgExpiredCheckpointId = runtimeContext.checkpointIds.get(0);
+      RemoteCallWorker.clearExpiredCheckpointParallel(allActor, 0L, msgExpiredCheckpointId);
+    }
+
+    if (runtimeContext.checkpointIds.size() > 1) {
+      Long stateExpiredCpId = runtimeContext.checkpointIds.remove(0);
+      Long msgExpiredCheckpointId = runtimeContext.checkpointIds.get(0);
+      RemoteCallWorker
+          .clearExpiredCheckpointParallel(allActor, stateExpiredCpId, msgExpiredCheckpointId);
+    }
+    return true;
+  }
+
+  /** Returns whether at least the checkpoint interval has passed since the last timestamp. */
+  private boolean readyToTrigger() {
+    return (System.currentTimeMillis() - runtimeContext.lastCpTimestamp)
+        >= cpIntervalSecs * MILLIS_PER_SECOND;
+  }
+
+  /** Returns whether at least the checkpoint timeout has passed since the last timestamp. */
+  private boolean timeoutOnWaitCheckpoint() {
+    return (System.currentTimeMillis() - runtimeContext.lastCpTimestamp)
+        >= cpTimeoutSecs * MILLIS_PER_SECOND;
+  }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/FailoverCoordinator.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/FailoverCoordinator.java
new file mode 100644
index 000000000..c58c84d6a
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/FailoverCoordinator.java
@@ -0,0 +1,281 @@
+package io.ray.streaming.runtime.master.coordinator;
+
+import io.ray.api.BaseActorHandle;
+import io.ray.api.id.ActorId;
+import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
+import io.ray.streaming.runtime.core.resource.Container;
+import io.ray.streaming.runtime.master.JobMaster;
+import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext;
+import io.ray.streaming.runtime.master.coordinator.command.BaseWorkerCmd;
+import io.ray.streaming.runtime.master.coordinator.command.InterruptCheckpointRequest;
+import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest;
+import io.ray.streaming.runtime.rpc.async.AsyncRemoteCaller;
+import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo;
+import io.ray.streaming.runtime.util.ResourceUtil;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+import org.apache.commons.collections.map.DefaultedMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class FailoverCoordinator extends BaseCoordinator {
+
+ private static final Logger LOG = LoggerFactory.getLogger(FailoverCoordinator.class);
+
+ private static final int ROLLBACK_RETRY_TIME_MS = 10 * 1000;
+ private final Object cmdLock = new Object();
+ private final AsyncRemoteCaller asyncRemoteCaller;
+ private long currentCascadingGroupId = 0;
+ private final Map isRollbacking =
+ DefaultedMap.decorate(new ConcurrentHashMap(), false);
+
+ public FailoverCoordinator(JobMaster jobMaster, boolean isRecover) {
+ this(jobMaster, new AsyncRemoteCaller(), isRecover); // delegate with a newly created AsyncRemoteCaller
+ }
+
+ public FailoverCoordinator(
+ JobMaster jobMaster, AsyncRemoteCaller asyncRemoteCaller,
+ boolean isRecover) {
+ super(jobMaster);
+
+ this.asyncRemoteCaller = asyncRemoteCaller;
+ // When recovering, re-queue the failover commands that were still unfinished.
+ JobMasterRuntimeContext runtimeContext = jobMaster.getRuntimeContext();
+ if (isRecover) {
+ runtimeContext.foCmds.addAll(runtimeContext.unfinishedFoCmds);
+ }
+ runtimeContext.unfinishedFoCmds.clear(); // cleared whether or not isRecover is set
+ }
+
+ @Override
+ public void run() { // Polls failover commands in a loop while `closed` is false.
+ while (!closed) {
+ try {
+ final BaseWorkerCmd command;
+ // Poll under cmdLock; see rollback() for the reason the lock is needed.
+ synchronized (cmdLock) {
+ command = jobMaster.getRuntimeContext().foCmds.poll(1, TimeUnit.SECONDS);
+ }
+ if (null == command) { // nothing arrived within the 1 second poll timeout
+ continue;
+ }
+ if (command instanceof WorkerRollbackRequest) { // other command types are dropped here
+ jobMaster.getRuntimeContext().unfinishedFoCmds.add(command); // removed in rollback()'s result callback
+ dealWithRollbackRequest((WorkerRollbackRequest) command);
+ }
+ } catch (Throwable e) { // log and keep looping so a single failure does not end the thread
+ LOG.error("Fo coordinator occur err.", e);
+ }
+ }
+ LOG.warn("Fo coordinator thread exit.");
+ }
+
+ private Boolean isDuplicateRequest(WorkerRollbackRequest request) { // true if a queued command has the same fromActorId
+ try {
+ Object[] foCmdsArray = runtimeContext.foCmds.toArray(); // scan an array copy of the queue
+ for (Object cmd : foCmdsArray) {
+ if (request.fromActorId.equals(((BaseWorkerCmd) cmd).fromActorId)) {
+ return true;
+ }
+ }
+ } catch (Exception e) { // on any failure the request is treated as not duplicated
+ LOG.warn("Check request is duplicated failed.", e);
+ }
+ return false;
+ }
+
+ public Boolean requestJobWorkerRollback(WorkerRollbackRequest request) { // true if queued or skipped as a duplicate
+ LOG.info("Request job worker rollback {}.", request);
+ boolean ret;
+ if (!isDuplicateRequest(request)) {
+ ret = runtimeContext.foCmds.offer(request); // false when the queue rejects the request
+ } else {
+ LOG.warn("Skip duplicated worker rollback request, {}.", request.toString());
+ return true;
+ }
+ jobMaster.saveContext(); // NOTE(review): also called when offer() returned false
+ if (!ret) {
+ LOG.warn("Request job worker rollback failed, because command queue is full.");
+ }
+ return ret;
+ }
+
+ private void dealWithRollbackRequest(WorkerRollbackRequest rollbackRequest) {
+ LOG.info("Start deal with rollback request {}.", rollbackRequest);
+
+ ExecutionVertex exeVertex = getExeVertexFromRequest(rollbackRequest);
+
+ // Update the vertex pid when the request carries a non-null, non-default pid.
+ if (null != rollbackRequest.getPid() &&
+ !rollbackRequest.getPid().equals(WorkerRollbackRequest.DEFAULT_PID)) {
+ exeVertex.setPid(rollbackRequest.getPid());
+ }
+
+ if (isRollbacking.get(exeVertex)) { // isRollbacking defaults to false for unknown vertices
+ LOG.info("Vertex {} is rollbacking, skip rollback again.", exeVertex);
+ return; // NOTE(review): the request stays in unfinishedFoCmds (added in run()) — confirm intended
+ }
+
+ String hostname = ""; // NOTE(review): hostname is assigned below but never read in this method
+ Optional container = ResourceUtil.getContainerById(
+ jobMaster.getResourceManager().getRegisteredContainers(),
+ exeVertex.getContainerId()
+ );
+ if (container.isPresent()) {
+ hostname = container.get().getHostname();
+ }
+
+ if (rollbackRequest.isForcedRollback) { // forced: roll back without asking the worker first
+ interruptCheckpointAndRollback(rollbackRequest);
+ } else {
+ asyncRemoteCaller.checkIfNeedRollbackAsync(exeVertex.getWorkerActor(), res -> {
+ if (!res) {
+ LOG.info("Vertex {} doesn't need to rollback, skip it.", exeVertex);
+ return;
+ }
+ interruptCheckpointAndRollback(rollbackRequest);
+ }, throwable -> {
+ LOG.error("Exception when calling checkIfNeedRollbackAsync, maybe vertex is dead" +
+ ", ignore this request, vertex={}.", exeVertex, throwable);
+ });
+ }
+
+ LOG.info("Deal with rollback request {} success.", rollbackRequest); // NOTE(review): presumably logged before the async callbacks run — confirm
+ }
+
+ private void interruptCheckpointAndRollback(WorkerRollbackRequest rollbackRequest) {
+ // Assign a fresh cascadingGroupId unless the request already belongs to a cascading group.
+ if (rollbackRequest.cascadingGroupId == null) {
+ rollbackRequest.cascadingGroupId = currentCascadingGroupId++;
+ }
+ // Roll back to the last valid checkpoint under the request's own group id (not the counter).
+ rollback(jobMaster.getRuntimeContext().getLastValidCheckpointId(), rollbackRequest,
+ rollbackRequest.cascadingGroupId);
+ // The current checkpoint is interrupted for 2 reasons:
+ // 1. the current checkpoint might time out, because a barrier might be lost after failover,
+ // so it is interrupted to avoid waiting.
+ // 2. when we want to roll a vertex back to n, the job may have finished checkpoint n+1 and
+ // cleared the state of checkpoint n.
+ jobMaster.getRuntimeContext().cpCmds.offer(new InterruptCheckpointRequest());
+ }
+
+ /**
+ * Calls the worker to roll back and deals with its report. The callback won't be finished
+ * until the entire DAG is back to normal.
+ *
+ * @param checkpointId checkpoint id to roll back to
+ * @param rollbackRequest worker rollback request
+ * @param cascadingGroupId all rollbacks of a cascading group should have the same ID
+ */
+ private void rollback(
+ long checkpointId, WorkerRollbackRequest rollbackRequest,
+ long cascadingGroupId) {
+ ExecutionVertex exeVertex = getExeVertexFromRequest(rollbackRequest);
+ LOG.info("Call vertex {} to rollback, checkpoint id is {}, cascadingGroupId={}.",
+ exeVertex, checkpointId, cascadingGroupId);
+
+ isRollbacking.put(exeVertex, true); // reset to false in both callbacks below
+
+ asyncRemoteCaller.rollback(exeVertex.getWorkerActor(), checkpointId, result -> {
+ List newRollbackRequests = new ArrayList<>(); // follow-up requests to enqueue once this one is handled
+ switch (result.getResultEnum()) {
+ case SUCCESS:
+ ChannelRecoverInfo recoverInfo = result.getResultObj();
+ LOG.info("Vertex {} rollback done, dataLostQueues={}, msg={}, cascadingGroupId={}.",
+ exeVertex, recoverInfo.getDataLostQueues(), result.getResultMsg(), cascadingGroupId);
+ // Roll back upstream vertices if this vertex reports abnormal input queues.
+ newRollbackRequests =
+ cascadeUpstreamActors(recoverInfo.getDataLostQueues(), exeVertex, cascadingGroupId);
+ break;
+ case SKIPPED:
+ LOG.info("Vertex skip rollback, result = {}, cascadingGroupId={}.", result,
+ cascadingGroupId);
+ break;
+ default: // any other result: wait, then queue a non-forced retry for the same vertex
+ LOG.error(
+ "Rollback vertex {} failed, result={}, cascadingGroupId={}," +
+ " rollback this worker again after {} ms.",
+ exeVertex, result, cascadingGroupId, ROLLBACK_RETRY_TIME_MS);
+ Thread.sleep(ROLLBACK_RETRY_TIME_MS); // NOTE(review): blocks the thread running this callback
+ LOG.info("Add rollback request for {} again, cascadingGroupId={}.", exeVertex,
+ cascadingGroupId);
+ newRollbackRequests.add(
+ new WorkerRollbackRequest(exeVertex, "", "Rollback failed, try again.", false)
+ );
+ break;
+ }
+
+ // Lock so that newly added rollback requests are not executed before they are saved.
+ // Consider such a case: A->B->C, C cascades B, and B cascades A.
+ // If B is rolled back before B's rollback request is saved, and then JobMaster crashes,
+ // then A will never be rolled back.
+ synchronized (cmdLock) {
+ jobMaster.getRuntimeContext().foCmds.addAll(newRollbackRequests);
+ // This rollback request is finished, remove it.
+ jobMaster.getRuntimeContext().unfinishedFoCmds.remove(rollbackRequest);
+ jobMaster.saveContext();
+ }
+ isRollbacking.put(exeVertex, false);
+ }, throwable -> { // NOTE(review): this path leaves the request in unfinishedFoCmds — confirm intended
+ LOG.error("Exception when calling vertex to rollback, vertex={}.", exeVertex, throwable);
+ isRollbacking.put(exeVertex, false);
+ });
+
+ LOG.info("Finish rollback vertex {}, checkpoint id is {}.", exeVertex, checkpointId);
+ }
+
+ private List cascadeUpstreamActors( // Builds forced rollback requests for the peers of the data-lost queues.
+ Set dataLostQueues, ExecutionVertex fromVertex, long cascadingGroupId) {
+ List cascadedRollbackRequest = new ArrayList<>();
+ // Roll back upstream vertices if the vertex reports abnormal input queues.
+ dataLostQueues.forEach(q -> {
+ BaseActorHandle upstreamActor =
+ graphManager.getExecutionGraph().getPeerActor(fromVertex.getWorkerActor(), q);
+ ExecutionVertex upstreamExeVertex = getExecutionVertex(upstreamActor);
+ // Vertices that have already been cascaded by another vertex in the same level
+ // of the graph should be ignored.
+ if (isRollbacking.get(upstreamExeVertex)) {
+ return; // skips only this queue; forEach continues with the next one
+ }
+ LOG.info("Call upstream vertex {} of vertex {} to rollback, cascadingGroupId={}.",
+ upstreamExeVertex, fromVertex, cascadingGroupId);
+ String hostname = ""; // stays empty when the container is not found
+ Optional container = ResourceUtil.getContainerById(
+ jobMaster.getResourceManager().getRegisteredContainers(),
+ upstreamExeVertex.getContainerId()
+ );
+ if (container.isPresent()) {
+ hostname = container.get().getHostname();
+ }
+ // Force upstream vertices to roll back.
+ WorkerRollbackRequest upstreamRequest = new WorkerRollbackRequest(
+ upstreamExeVertex, hostname, String.format("Cascading rollback from %s", fromVertex), true
+ );
+ upstreamRequest.cascadingGroupId = cascadingGroupId; // cascaded requests inherit the caller's group id
+ cascadedRollbackRequest.add(upstreamRequest);
+ });
+ return cascadedRollbackRequest;
+ }
+
+ private ExecutionVertex getExeVertexFromRequest(WorkerRollbackRequest rollbackRequest) {
+ ActorId actorId = rollbackRequest.fromActorId;
+ Optional rayActor = graphManager.getExecutionGraph().getActorById(actorId);
+ if (!rayActor.isPresent()) { // fail fast when the execution graph has no actor for this id
+ throw new RuntimeException("Can not find ray actor of ID " + actorId);
+ }
+ return getExecutionVertex(rollbackRequest.fromActorId); // resolved via the ActorId overload
+ }
+
+ private ExecutionVertex getExecutionVertex(BaseActorHandle actor) { // looks the vertex up by the handle's actor id
+ return graphManager.getExecutionGraph().getExecutionVertexByActorId(actor.getId());
+ }
+
+ private ExecutionVertex getExecutionVertex(ActorId actorId) { // looks the vertex up in the execution graph by actor id
+ return graphManager.getExecutionGraph().getExecutionVertexByActorId(actorId);
+ }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/BaseWorkerCmd.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/BaseWorkerCmd.java
new file mode 100644
index 000000000..2c6a9322d
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/BaseWorkerCmd.java
@@ -0,0 +1,17 @@
+package io.ray.streaming.runtime.master.coordinator.command;
+
+import io.ray.api.id.ActorId;
+import java.io.Serializable;
+
+public abstract class BaseWorkerCmd implements Serializable { // Serializable base type for worker commands.
+
+ public ActorId fromActorId; // id of the actor this command refers to; null with the no-arg constructor
+
+ public BaseWorkerCmd() { // leaves fromActorId unset
+ }
+
+ protected BaseWorkerCmd(ActorId actorId) {
+ this.fromActorId = actorId;
+ }
+
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/InterruptCheckpointRequest.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/InterruptCheckpointRequest.java
new file mode 100644
index 000000000..29a46ab10
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/InterruptCheckpointRequest.java
@@ -0,0 +1,5 @@
+package io.ray.streaming.runtime.master.coordinator.command;
+
+public final class InterruptCheckpointRequest extends BaseWorkerCmd { // marker command: declares no fields or methods
+
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerCommitReport.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerCommitReport.java
new file mode 100644
index 000000000..7750ce1b0
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerCommitReport.java
@@ -0,0 +1,22 @@
+package io.ray.streaming.runtime.master.coordinator.command;
+
+import com.google.common.base.MoreObjects;
+import io.ray.api.id.ActorId;
+
+public final class WorkerCommitReport extends BaseWorkerCmd { // Command carrying a checkpoint id together with an actor id.
+
+ public final long commitCheckpointId; // checkpoint id supplied to the constructor
+
+ public WorkerCommitReport(ActorId actorId, long commitCheckpointId) {
+ super(actorId); // stored as fromActorId by BaseWorkerCmd
+ this.commitCheckpointId = commitCheckpointId;
+ }
+
+ @Override
+ public String toString() { // includes both commitCheckpointId and fromActorId
+ return MoreObjects.toStringHelper(this)
+ .add("commitCheckpointId", commitCheckpointId)
+ .add("fromActorId", fromActorId)
+ .toString();
+ }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerRollbackRequest.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerRollbackRequest.java
new file mode 100644
index 000000000..e56518382
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerRollbackRequest.java
@@ -0,0 +1,63 @@
+package io.ray.streaming.runtime.master.coordinator.command;
+
+import com.google.common.base.MoreObjects;
+import io.ray.api.id.ActorId;
+import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
+
+public final class WorkerRollbackRequest extends BaseWorkerCmd { // Command asking for the rollback of one worker actor.
+
+ public static String DEFAULT_PID = "UNKNOWN_PID"; // placeholder pid; NOTE(review): non-final, so reassignable
+ public Long cascadingGroupId = null; // null until a group id is assigned
+ public boolean isForcedRollback = false; // only set by the ExecutionVertex-based constructor
+ private String exceptionMsg = "No detail message.";
+ private String hostname = "UNKNOWN_HOST";
+ private String pid = DEFAULT_PID;
+
+ public WorkerRollbackRequest(ActorId actorId) { // keeps the default message, hostname and pid
+ super(actorId);
+ }
+
+ public WorkerRollbackRequest(ActorId actorId, String msg) {
+ super(actorId);
+ exceptionMsg = msg;
+ }
+
+ public WorkerRollbackRequest( // takes the actor id and pid from the given execution vertex
+ ExecutionVertex executionVertex,
+ String hostname,
+ String msg,
+ boolean isForcedRollback) {
+
+ super(executionVertex.getWorkerActorId());
+
+ this.hostname = hostname;
+ this.pid = executionVertex.getPid();
+ this.exceptionMsg = msg;
+ this.isForcedRollback = isForcedRollback;
+ }
+
+ public WorkerRollbackRequest(ActorId actorId, String msg, String hostname, String pid) { // note: msg precedes hostname here
+ this(actorId, msg);
+ this.hostname = hostname;
+ this.pid = pid;
+ }
+
+ public String getRollbackExceptionMsg() {
+ return exceptionMsg;
+ }
+
+ public String getHostname() {
+ return hostname;
+ }
+
+ public String getPid() {
+ return pid;
+ }
+
+ @Override
+ public String toString() { // only fromActorId is included in the string form
+ return MoreObjects.toStringHelper(this)
+ .add("fromActorId", fromActorId)
+ .toString();
+ }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java
index e76963a47..a977967ff 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java
@@ -1,14 +1,19 @@
package io.ray.streaming.runtime.master.graphmanager;
+import io.ray.api.BaseActorHandle;
import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.jobgraph.JobVertex;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionEdge;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobEdge;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobVertex;
-import io.ray.streaming.runtime.master.JobRuntimeContext;
+import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
+import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext;
+import java.util.HashMap;
+import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
+import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -16,9 +21,9 @@ public class GraphManagerImpl implements GraphManager {
private static final Logger LOG = LoggerFactory.getLogger(GraphManagerImpl.class);
- protected final JobRuntimeContext runtimeContext;
+ protected final JobMasterRuntimeContext runtimeContext;
- public GraphManagerImpl(JobRuntimeContext runtimeContext) {
+ public GraphManagerImpl(JobMasterRuntimeContext runtimeContext) {
this.runtimeContext = runtimeContext;
}
@@ -48,6 +53,7 @@ public class GraphManagerImpl implements GraphManager {
// create vertex
Map exeJobVertexMap = new LinkedHashMap<>();
+ Map executionVertexMap = new HashMap<>();
long buildTime = executionGraph.getBuildTime();
for (JobVertex jobVertex : jobGraph.getJobVertices()) {
int jobVertexId = jobVertex.getVertexId();
@@ -59,32 +65,47 @@ public class GraphManagerImpl implements GraphManager {
buildTime));
}
- // connect vertex
+ // for each job edge, connect all source exeVertices and target exeVertices
jobGraph.getJobEdges().forEach(jobEdge -> {
ExecutionJobVertex source = exeJobVertexMap.get(jobEdge.getSrcVertexId());
ExecutionJobVertex target = exeJobVertexMap.get(jobEdge.getTargetVertexId());
- ExecutionJobEdge executionJobEdge =
- new ExecutionJobEdge(source, target, jobEdge);
+ ExecutionJobEdge executionJobEdge = new ExecutionJobEdge(source, target, jobEdge);
source.getOutputEdges().add(executionJobEdge);
target.getInputEdges().add(executionJobEdge);
- source.getExecutionVertices().forEach(vertex -> {
- target.getExecutionVertices().forEach(outputVertex -> {
- ExecutionEdge executionEdge = new ExecutionEdge(vertex, outputVertex, executionJobEdge);
- vertex.getOutputEdges().add(executionEdge);
- outputVertex.getInputEdges().add(executionEdge);
+ source.getExecutionVertices().forEach(sourceExeVertex -> {
+ target.getExecutionVertices().forEach(targetExeVertex -> {
+ // pre-process some mappings
+ executionVertexMap.put(targetExeVertex.getExecutionVertexId(), targetExeVertex);
+ executionVertexMap.put(sourceExeVertex.getExecutionVertexId(), sourceExeVertex);
+ // build execution edge
+ ExecutionEdge executionEdge =
+ new ExecutionEdge(sourceExeVertex, targetExeVertex, executionJobEdge);
+ sourceExeVertex.getOutputEdges().add(executionEdge);
+ targetExeVertex.getInputEdges().add(executionEdge);
});
});
});
// set execution job vertex into execution graph
executionGraph.setExecutionJobVertexMap(exeJobVertexMap);
+ executionGraph.setExecutionVertexMap(executionVertexMap);
return executionGraph;
}
+ private void addActorToChannelGroupedActors( // Adds the actor to the set stored under channelId.
+ Map> channelGroupedActors,
+ String channelId,
+ BaseActorHandle actor) {
+
+ Set actorSet =
+ channelGroupedActors.computeIfAbsent(channelId, k -> new HashSet<>()); // creates the set on first use of a channel id
+ actorSet.add(actor);
+ }
+
@Override
public JobGraph getJobGraph() {
return runtimeContext.getJobGraph();
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerImpl.java
index 3b7b35ba6..2e59fed09 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerImpl.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerImpl.java
@@ -11,7 +11,7 @@ import io.ray.streaming.runtime.config.types.ResourceAssignStrategyType;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph;
import io.ray.streaming.runtime.core.resource.Container;
import io.ray.streaming.runtime.core.resource.Resources;
-import io.ray.streaming.runtime.master.JobRuntimeContext;
+import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext;
import io.ray.streaming.runtime.master.resourcemanager.strategy.ResourceAssignStrategy;
import io.ray.streaming.runtime.master.resourcemanager.strategy.ResourceAssignStrategyFactory;
import io.ray.streaming.runtime.util.RayUtils;
@@ -30,39 +30,33 @@ public class ResourceManagerImpl implements ResourceManager {
//Container used tag
private static final String CONTAINER_ENGAGED_KEY = "CONTAINER_ENGAGED_KEY";
-
- /**
- * Job runtime context.
- */
- private JobRuntimeContext runtimeContext;
-
- /**
- * Resource related configuration.
- */
- private ResourceConfig resourceConfig;
-
- /**
- * Slot assign strategy.
- */
- private ResourceAssignStrategy resourceAssignStrategy;
-
/**
* Resource description information.
*/
private final Resources resources;
-
- /**
- * Customized actor number for each container
- */
- private int actorNumPerContainer;
-
/**
* Timing resource updating thread
*/
private final ScheduledExecutorService resourceUpdater = new ScheduledThreadPoolExecutor(1,
new ThreadFactoryBuilder().setNameFormat("resource-update-thread").build());
+ /**
+ * Job runtime context.
+ */
+ private JobMasterRuntimeContext runtimeContext;
+ /**
+ * Resource related configuration.
+ */
+ private ResourceConfig resourceConfig;
+ /**
+ * Slot assign strategy.
+ */
+ private ResourceAssignStrategy resourceAssignStrategy;
+ /**
+ * Customized actor number for each container
+ */
+ private int actorNumPerContainer;
- public ResourceManagerImpl(JobRuntimeContext runtimeContext) {
+ public ResourceManagerImpl(JobMasterRuntimeContext runtimeContext) {
this.runtimeContext = runtimeContext;
StreamingMasterConfig masterConfig = runtimeContext.getConf().masterConfig;
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobSchedulerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobSchedulerImpl.java
index 6b9b3a690..238fdf6f7 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobSchedulerImpl.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobSchedulerImpl.java
@@ -23,20 +23,18 @@ import org.slf4j.LoggerFactory;
public class JobSchedulerImpl implements JobScheduler {
private static final Logger LOG = LoggerFactory.getLogger(JobSchedulerImpl.class);
-
- private StreamingConfig jobConf;
-
private final JobMaster jobMaster;
private final ResourceManager resourceManager;
private final GraphManager graphManager;
private final WorkerLifecycleController workerLifecycleController;
+ private StreamingConfig jobConfig;
public JobSchedulerImpl(JobMaster jobMaster) {
this.jobMaster = jobMaster;
this.graphManager = jobMaster.getGraphManager();
this.resourceManager = jobMaster.getResourceManager();
this.workerLifecycleController = new WorkerLifecycleController();
- this.jobConf = jobMaster.getRuntimeContext().getConf();
+ this.jobConfig = jobMaster.getRuntimeContext().getConf();
LOG.info("Scheduler initiated.");
}
@@ -46,8 +44,13 @@ public class JobSchedulerImpl implements JobScheduler {
LOG.info("Begin scheduling. Job: {}.", executionGraph.getJobName());
// Allocate resource then create workers
+ // Actor creation is in this step
prepareResourceAndCreateWorker(executionGraph);
+ // now actor info is available in execution graph
+ // preprocess some handy mappings in execution graph
+ executionGraph.generateActorMappings();
+
// init worker context and start to run
initAndStart(executionGraph);
@@ -87,7 +90,7 @@ public class JobSchedulerImpl implements JobScheduler {
initMaster();
// start workers
- startWorkers(executionGraph);
+ startWorkers(executionGraph, jobMaster.getRuntimeContext().lastCheckpointId);
}
/**
@@ -122,7 +125,7 @@ public class JobSchedulerImpl implements JobScheduler {
boolean result;
try {
result = workerLifecycleController.initWorkers(vertexToContextMap,
- jobConf.masterConfig.schedulerConfig.workerInitiationWaitTimeoutMs());
+ jobConfig.masterConfig.schedulerConfig.workerInitiationWaitTimeoutMs());
} catch (Exception e) {
LOG.error("Failed to initiate workers.", e);
return false;
@@ -133,11 +136,12 @@ public class JobSchedulerImpl implements JobScheduler {
/**
* Start JobWorkers according to the physical plan.
*/
- public boolean startWorkers(ExecutionGraph executionGraph) {
+ public boolean startWorkers(ExecutionGraph executionGraph, long checkpointId) {
boolean result;
try {
result = workerLifecycleController.startWorkers(
- executionGraph, jobConf.masterConfig.schedulerConfig.workerStartingWaitTimeoutMs());
+ executionGraph, checkpointId,
+ jobConfig.masterConfig.schedulerConfig.workerStartingWaitTimeoutMs());
} catch (Exception e) {
LOG.error("Failed to start workers.", e);
return false;
@@ -194,7 +198,7 @@ public class JobSchedulerImpl implements JobScheduler {
}
private void initMaster() {
- jobMaster.init();
+ jobMaster.init(false);
}
}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/controller/WorkerLifecycleController.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/controller/WorkerLifecycleController.java
index 876e9f924..bc8b462c7 100644
--- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/controller/WorkerLifecycleController.java
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/controller/WorkerLifecycleController.java
@@ -9,6 +9,8 @@ import io.ray.api.id.ActorId;
import io.ray.streaming.api.Language;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph;
import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
+import io.ray.streaming.runtime.generated.RemoteCall;
+import io.ray.streaming.runtime.python.GraphPbBuilder;
import io.ray.streaming.runtime.rpc.RemoteCallWorker;
import io.ray.streaming.runtime.worker.JobWorker;
import io.ray.streaming.runtime.worker.context.JobWorkerContext;
@@ -40,20 +42,23 @@ public class WorkerLifecycleController {
* @return creation result
*/
private boolean createWorker(ExecutionVertex executionVertex) {
- LOG.info("Start to create worker actor for vertex: {} with resource: {}.",
- executionVertex.getExecutionVertexName(), executionVertex.getResource());
+ LOG.info("Start to create worker actor for vertex: {} with resource: {}, workeConfig: {}.",
+ executionVertex.getExecutionVertexName(), executionVertex.getResource(),
+ executionVertex.getWorkerConfig());
Language language = executionVertex.getLanguage();
BaseActorHandle actor;
if (Language.JAVA == language) {
- actor = Ray.actor(JobWorker::new)
+ actor = Ray.actor(JobWorker::new, executionVertex)
.setResources(executionVertex.getResource())
.setMaxRestarts(-1)
.remote();
} else {
+ RemoteCall.ExecutionVertexContext.ExecutionVertex vertexPb
+ = new GraphPbBuilder().buildVertex(executionVertex);
actor = Ray.actor(
- PyActorClass.of("ray.streaming.runtime.worker", "JobWorker"))
+ PyActorClass.of("ray.streaming.runtime.worker", "JobWorker"), vertexPb.toByteArray())
.setResources(executionVertex.getResource())
.setMaxRestarts(-1)
.remote();
@@ -111,20 +116,20 @@ public class WorkerLifecycleController {
* @param timeout timeout for waiting, unit: ms
* @return starting result
*/
- public boolean startWorkers(ExecutionGraph executionGraph, int timeout) {
+ public boolean startWorkers(ExecutionGraph executionGraph, long lastCheckpointId, int timeout) {
LOG.info("Begin starting workers.");
long startTime = System.currentTimeMillis();
- List> objectRefs = new ArrayList<>();
+ List> objectRefs = new ArrayList<>();
// start source actors 1st
executionGraph.getSourceActors()
- .forEach(actor -> objectRefs.add(RemoteCallWorker.startWorker(actor)));
+ .forEach(actor -> objectRefs.add(RemoteCallWorker.rollback(actor, lastCheckpointId)));
// then start non-source actors
executionGraph.getNonSourceActors()
- .forEach(actor -> objectRefs.add(RemoteCallWorker.startWorker(actor)));
+ .forEach(actor -> objectRefs.add(RemoteCallWorker.rollback(actor, lastCheckpointId)));
- WaitResult result = Ray.wait(objectRefs, objectRefs.size(), timeout);
+ WaitResult