Revert "[Streaming] Fault Tolerance Implementation (#10008)" (#10582)

This reverts commit 1b1466748f.
This commit is contained in:
SangBin Cho
2020-09-04 13:21:18 -07:00
committed by GitHub
parent da83bbd764
commit cb919c5e5c
158 changed files with 1227 additions and 7040 deletions
-30
View File
@@ -1,30 +0,0 @@
class BaseWorkerCmd:
    """Base class for worker commands.

    Stores the actor id passed to the constructor as ``from_actor_id``.
    """

    def __init__(self, actor_id):
        # Kept under the name ``from_actor_id`` for subclasses and callers.
        self.from_actor_id = actor_id
class WorkerCommitReport(BaseWorkerCmd):
    """Worker command that carries the id of a committed checkpoint."""

    def __init__(self, actor_id, commit_checkpoint_id):
        # The base class records ``actor_id`` as ``from_actor_id``.
        super().__init__(actor_id)
        self.commit_checkpoint_id = commit_checkpoint_id
class WorkerRollbackRequest(BaseWorkerCmd):
    """Worker command asking for a rollback, with the triggering error text."""

    def __init__(self, actor_id, exception_msg):
        # The base class records ``actor_id`` as ``from_actor_id``.
        super().__init__(actor_id)
        # Name-mangled private storage; read it through ``exception_msg()``.
        self.__exception_msg = exception_msg

    def exception_msg(self):
        """Return the exception message given at construction time."""
        return self.__exception_msg
-117
View File
@@ -1,117 +0,0 @@
import logging
import os
from abc import ABC, abstractmethod
from os import path
from ray.streaming.config import ConfigHelper, Config
logger = logging.getLogger(__name__)
class ContextBackend(ABC):
    """Abstract key/value store interface used to keep context data.

    Concrete backends must implement ``get``, ``put`` and ``remove``.
    """

    @abstractmethod
    def get(self, key):
        """Return the value stored under ``key``."""

    @abstractmethod
    def put(self, key, value):
        """Store ``value`` under ``key``."""

    @abstractmethod
    def remove(self, key):
        """Delete the entry stored under ``key``."""
class MemoryContextBackend(ContextBackend):
    """Context backend that keeps all entries in an in-process dict."""

    def __init__(self, conf):
        # ``conf`` is accepted for interface parity with the other backends
        # but is not read here.
        self.__dic = {}

    def get(self, key):
        """Return the stored value, or None when ``key`` is absent."""
        return self.__dic.get(key)

    def put(self, key, value):
        """Store ``value`` under ``key``, replacing any previous value."""
        self.__dic[key] = value

    def remove(self, key):
        """Drop ``key`` if present; a missing key is not an error."""
        self.__dic.pop(key, None)
class LocalFileContextBackend(ContextBackend):
    """Context backend that stores each key as one file in a root directory.

    The root directory is read from the configuration through
    ``ConfigHelper.get_cp_local_file_root_dir`` and created on construction.
    Values are written and read as raw bytes.
    """

    def __init__(self, conf):
        self.__dir = ConfigHelper.get_cp_local_file_root_dir(conf)
        logger.info("Start init local file state backend, root_dir={}.".format(
            self.__dir))
        try:
            # makedirs (without exist_ok) also creates missing parent
            # directories, yet still raises FileExistsError when the root
            # directory itself already exists.
            os.makedirs(self.__dir)
        except FileExistsError:
            logger.info("dir already exists, skipped.")

    def put(self, key, value):
        """Write ``value`` (bytes) to the file that backs ``key``."""
        logger.info("Put value of key {} start.".format(key))
        with open(self.__gen_file_path(key), "wb") as f:
            f.write(value)

    def get(self, key):
        """Return the bytes stored for ``key``, or None if no file exists."""
        logger.info("Get value of key {} start.".format(key))
        full_path = self.__gen_file_path(key)
        if not os.path.isfile(full_path):
            return None
        with open(full_path, "rb") as f:
            return f.read()

    def remove(self, key):
        """Delete the file that backs ``key`` on a best-effort basis.

        A missing file is expected and ignored silently. Any other failure
        is still swallowed, but is logged so that it is no longer invisible.
        """
        logger.info("Remove value of key {} start.".format(key))
        try:
            os.remove(self.__gen_file_path(key))
        except FileNotFoundError:
            pass
        except Exception:
            logger.warning(
                "Remove value of key {} failed, ignored.".format(key),
                exc_info=True)

    def rename(self, src, dst):
        """Rename the file that backs ``src`` to the file that backs ``dst``."""
        logger.info("rename {} to {}".format(src, dst))
        os.rename(self.__gen_file_path(src), self.__gen_file_path(dst))

    def exists(self, key) -> bool:
        """Return True when a file for ``key`` exists under the root dir."""
        # Resolve the key against the root directory, as every other method
        # does. Checking the bare key would test a path relative to the
        # process working directory instead.
        return os.path.exists(self.__gen_file_path(key))

    def __gen_file_path(self, key):
        """Return the path of the file that backs ``key``."""
        return path.join(self.__dir, key)
class AtomicFsContextBackend(LocalFileContextBackend):
    """Local-file backend that writes through a temporary ``<key>_tmp`` file.

    ``put`` writes the new value to the temporary key, removes the old key
    and renames the temporary key into place. When only the temporary key
    is present, ``put`` first renames it to the real key and ``get`` reads
    from it directly.
    """

    def __init__(self, conf):
        super().__init__(conf)
        # Suffix appended to a key to form its temporary counterpart.
        self.__tmp_flag = "_tmp"

    def put(self, key, value):
        """Store ``value`` under ``key`` via write-to-tmp then rename."""
        parent = super()
        tmp_key = key + self.__tmp_flag
        # Only the temporary file is present: rename it to the real key
        # before it gets overwritten below.
        if parent.exists(tmp_key) and not parent.exists(key):
            parent.rename(tmp_key, key)
        parent.put(tmp_key, value)
        parent.remove(key)
        parent.rename(tmp_key, key)

    def get(self, key):
        """Return the value for ``key``, using the tmp file if only it exists."""
        parent = super()
        tmp_key = key + self.__tmp_flag
        only_tmp_present = parent.exists(tmp_key) and not parent.exists(key)
        lookup_key = tmp_key if only_tmp_present else key
        return parent.get(lookup_key)

    def remove(self, key):
        """Remove ``key`` and, when present, its temporary counterpart."""
        parent = super()
        tmp_key = key + self.__tmp_flag
        if parent.exists(tmp_key):
            parent.remove(tmp_key)
        parent.remove(key)
class ContextBackendFactory:
    """Creates the context backend selected by the worker configuration."""

    @staticmethod
    def get_context_backend(worker_config) -> ContextBackend:
        """Return a backend instance for the configured backend type.

        Returns None when the configured type matches neither the local-file
        nor the memory backend constant.
        """
        backend_type = ConfigHelper.get_cp_context_backend_type(worker_config)
        if backend_type == Config.CP_STATE_BACKEND_LOCAL_FILE:
            return AtomicFsContextBackend(worker_config)
        if backend_type == Config.CP_STATE_BACKEND_MEMORY:
            return MemoryContextBackend(worker_config)
        return None
-30
View File
@@ -1,30 +0,0 @@
class Barrier:
    """Simple value object that wraps a barrier id."""

    def __init__(self, id):
        # The parameter name ``id`` is part of the public signature.
        self.id = id

    def __str__(self):
        template = "Barrier [id:%s]"
        return template % self.id
class OpCheckpointInfo:
    """Container for an operator checkpoint and its input/output points.

    ``input_points`` and ``output_points`` default to a fresh empty dict per
    instance when omitted or passed as None.
    """

    def __init__(self,
                 operator_point=None,
                 input_points=None,
                 output_points=None,
                 checkpoint_id=None):
        self.operator_point = operator_point
        # Build a new dict for each instance instead of sharing a default.
        self.input_points = {} if input_points is None else input_points
        self.output_points = {} if output_points is None else output_points
        self.checkpoint_id = checkpoint_id
+42 -70
View File
@@ -5,9 +5,6 @@ import ray
import ray.streaming.generated.remote_call_pb2 as remote_call_pb
import ray.streaming.operator as operator
import ray.streaming.partition as partition
from ray._raylet import ActorID
from ray.actor import ActorHandle
from ray.streaming.config import Config
from ray.streaming.generated.streaming_pb2 import Language
logger = logging.getLogger(__name__)
@@ -30,12 +27,10 @@ class NodeType(enum.Enum):
class ExecutionEdge:
def __init__(self, execution_edge_pb, language):
self.source_execution_vertex_id = execution_edge_pb \
.source_execution_vertex_id
self.target_execution_vertex_id = execution_edge_pb \
.target_execution_vertex_id
partition_bytes = execution_edge_pb.partition
def __init__(self, edge_pb, language):
self.source_execution_vertex_id = edge_pb.source_execution_vertex_id
self.target_execution_vertex_id = edge_pb.target_execution_vertex_id
partition_bytes = edge_pb.partition
# Sink node doesn't have partition function,
# so we only deserialize partition_bytes when it's not None or empty
if language == Language.PYTHON and partition_bytes:
@@ -43,73 +38,50 @@ class ExecutionEdge:
class ExecutionVertex:
worker_actor: ActorHandle
def __init__(self, execution_vertex_pb):
self.execution_vertex_id = execution_vertex_pb.execution_vertex_id
self.execution_job_vertex_id = execution_vertex_pb \
.execution_job_vertex_id
self.execution_job_vertex_name = execution_vertex_pb \
.execution_job_vertex_name
self.execution_vertex_index = execution_vertex_pb\
.execution_vertex_index
self.parallelism = execution_vertex_pb.parallelism
if execution_vertex_pb\
.language == Language.PYTHON:
# python operator descriptor
operator_bytes = execution_vertex_pb.operator
if execution_vertex_pb.chained:
def __init__(self, vertex_pb):
self.execution_vertex_id = vertex_pb.execution_vertex_id
self.execution_job_vertex_Id = vertex_pb.execution_job_vertex_Id
self.execution_job_vertex_name = vertex_pb.execution_job_vertex_name
self.execution_vertex_index = vertex_pb.execution_vertex_index
self.parallelism = vertex_pb.parallelism
if vertex_pb.language == Language.PYTHON:
operator_bytes = vertex_pb.operator # python operator descriptor
if vertex_pb.chained:
logger.info("Load chained operator")
self.stream_operator = operator.load_chained_operator(
operator_bytes)
else:
logger.info("Load operator")
self.stream_operator = operator.load_operator(operator_bytes)
self.worker_actor = None
if execution_vertex_pb.worker_actor:
self.worker_actor = ray.actor.ActorHandle. \
_deserialization_helper(execution_vertex_pb.worker_actor)
self.container_id = execution_vertex_pb.container_id
self.build_time = execution_vertex_pb.build_time
self.language = execution_vertex_pb.language
self.config = execution_vertex_pb.config
self.resource = execution_vertex_pb.resource
@property
def execution_vertex_name(self):
return "{}_{}_{}".format(self.execution_job_vertex_id,
self.execution_job_vertex_name,
self.execution_vertex_id)
self.worker_actor = ray.actor.ActorHandle. \
_deserialization_helper(vertex_pb.worker_actor)
self.container_id = vertex_pb.container_id
self.build_time = vertex_pb.build_time
self.language = vertex_pb.language
self.config = vertex_pb.config
self.resource = vertex_pb.resource
class ExecutionVertexContext:
actor_id: ActorID
execution_vertex: ExecutionVertex
def __init__(
self,
execution_vertex_context_pb: remote_call_pb.ExecutionVertexContext
):
self.execution_vertex = ExecutionVertex(
execution_vertex_context_pb.current_execution_vertex)
self.job_name = self.execution_vertex.config[Config.STREAMING_JOB_NAME]
self.exe_vertex_name = self.execution_vertex.execution_vertex_name
self.actor_id = self.execution_vertex.worker_actor._ray_actor_id
def __init__(self,
vertex_context_pb: remote_call_pb.ExecutionVertexContext):
self.execution_vertex = \
ExecutionVertex(vertex_context_pb.current_execution_vertex)
self.upstream_execution_vertices = [
ExecutionVertex(vertex) for vertex in
execution_vertex_context_pb.upstream_execution_vertices
ExecutionVertex(vertex)
for vertex in vertex_context_pb.upstream_execution_vertices
]
self.downstream_execution_vertices = [
ExecutionVertex(vertex) for vertex in
execution_vertex_context_pb.downstream_execution_vertices
ExecutionVertex(vertex)
for vertex in vertex_context_pb.downstream_execution_vertices
]
self.input_execution_edges = [
ExecutionEdge(edge, self.execution_vertex.language)
for edge in execution_vertex_context_pb.input_execution_edges
for edge in vertex_context_pb.input_execution_edges
]
self.output_execution_edges = [
ExecutionEdge(edge, self.execution_vertex.language)
for edge in execution_vertex_context_pb.output_execution_edges
for edge in vertex_context_pb.output_execution_edges
]
def get_parallelism(self):
@@ -140,16 +112,16 @@ class ExecutionVertexContext:
def get_task_id(self):
return self.execution_vertex.execution_vertex_id
def get_source_actor_by_execution_vertex_id(self, execution_vertex_id):
    """Return the worker actor of the upstream vertex with the given id.

    Raises:
        Exception: if no upstream execution vertex has that id.
    """
    for execution_vertex in self.upstream_execution_vertices:
        if execution_vertex.execution_vertex_id == execution_vertex_id:
            return execution_vertex.worker_actor
    # Use a str.format placeholder: a "%s" placeholder is left untouched by
    # .format(), so the id would never appear in the message.
    raise Exception(
        "Vertex {} does not exist!".format(execution_vertex_id))
def get_source_actor_by_vertex_id(self, execution_vertex_id):
    """Return the worker actor of the upstream vertex with the given id.

    Raises:
        Exception: if no upstream execution vertex has that id.
    """
    for vertex in self.upstream_execution_vertices:
        if vertex.execution_vertex_id == execution_vertex_id:
            return vertex.worker_actor
    # Use a str.format placeholder: a "%s" placeholder is left untouched by
    # .format(), so the id would never appear in the message.
    raise Exception("ExecutionVertex {} does not exist!"
                    .format(execution_vertex_id))
def get_target_actor_by_execution_vertex_id(self, execution_vertex_id):
    """Return the worker actor of the downstream vertex with the given id.

    Raises:
        Exception: if no downstream execution vertex has that id.
    """
    for execution_vertex in self.downstream_execution_vertices:
        if execution_vertex.execution_vertex_id == execution_vertex_id:
            return execution_vertex.worker_actor
    # Use a str.format placeholder: a "%s" placeholder is left untouched by
    # .format(), so the id would never appear in the message.
    raise Exception(
        "Vertex {} does not exist!".format(execution_vertex_id))
def get_target_actor_by_vertex_id(self, execution_vertex_id):
    """Return the worker actor of the downstream vertex with the given id.

    Raises:
        Exception: if no downstream execution vertex has that id.
    """
    for vertex in self.downstream_execution_vertices:
        if vertex.execution_vertex_id == execution_vertex_id:
            return vertex.worker_actor
    # Use a str.format placeholder: a "%s" placeholder is left untouched by
    # .format(), so the id would never appear in the message.
    raise Exception("ExecutionVertex {} does not exist!"
                    .format(execution_vertex_id))
+3 -17
View File
@@ -23,14 +23,6 @@ class Processor(ABC):
def close(self):
pass
@abstractmethod
def save_checkpoint(self):
pass
@abstractmethod
def load_checkpoint(self, checkpoint_obj):
pass
class StreamingProcessor(Processor, ABC):
"""StreamingProcessor is a process unit for a operator."""
@@ -48,13 +40,7 @@ class StreamingProcessor(Processor, ABC):
logger.info("Opened Processor {}".format(self))
def close(self):
self.operator.close()
def save_checkpoint(self):
self.operator.save_checkpoint()
def load_checkpoint(self, checkpoint_obj):
self.operator.load_checkpoint(checkpoint_obj)
pass
class SourceProcessor(StreamingProcessor):
@@ -66,8 +52,8 @@ class SourceProcessor(StreamingProcessor):
def process(self, record):
raise Exception("SourceProcessor should not process record")
def fetch(self):
self.operator.fetch()
def run(self):
self.operator.run()
class OneInputProcessor(StreamingProcessor):
-95
View File
@@ -1,95 +0,0 @@
import logging
import os
import ray
import time
from enum import Enum
from ray.actor import ActorHandle
from ray.streaming.generated import remote_call_pb2
from ray.streaming.runtime.command\
import WorkerCommitReport, WorkerRollbackRequest
logger = logging.getLogger(__name__)
class CallResult:
    """Outcome of a call: success flag, result code, message and payload.

    NOTE: the ``success`` instance attribute set in ``__init__`` shadows the
    ``success`` factory on instances, so the factories should be invoked on
    the class (``CallResult.success(...)``).
    """

    def __init__(self, success, result_code, result_msg, result_obj):
        self.success = success
        self.result_code = result_code
        self.result_msg = result_msg
        self.result_obj = result_obj

    @staticmethod
    def success(payload=None):
        """Build a successful result that carries ``payload``."""
        return CallResult(
            success=True,
            result_code=CallResultEnum.SUCCESS,
            result_msg=None,
            result_obj=payload)

    @staticmethod
    def fail(payload=None):
        """Build a failed result that carries ``payload``."""
        return CallResult(
            success=False,
            result_code=CallResultEnum.FAILED,
            result_msg=None,
            result_obj=payload)

    @staticmethod
    def skipped(msg=None):
        """Build a skipped result (success flag True) with message ``msg``."""
        return CallResult(
            success=True,
            result_code=CallResultEnum.SKIPPED,
            result_msg=msg,
            result_obj=None)

    def is_success(self):
        """Return True only when the result code is SUCCESS."""
        return self.result_code is CallResultEnum.SUCCESS
class CallResultEnum(Enum):
    """Result codes attached to a ``CallResult``."""

    # Code set by the ``CallResult.success`` factory.
    SUCCESS = 0
    # Code set by the ``CallResult.fail`` factory.
    FAILED = 1
    # Code set by the ``CallResult.skipped`` factory.
    SKIPPED = 2
class RemoteCallMst:
    """Helpers for remote calls from a worker to the job master actor.

    Each public method wraps a detail message in a ``BaseWorkerCmd``
    protobuf, sends its serialized bytes to the master actor, waits for the
    reply with ``ray.get`` and returns the ``boolRes`` field of the parsed
    ``BoolResult``.
    """

    @staticmethod
    def _build_worker_cmd(actor_id, detail_pb):
        """Wrap ``detail_pb`` in a BaseWorkerCmd stamped with the current time."""
        cmd_pb = remote_call_pb2.BaseWorkerCmd()
        cmd_pb.actor_id = actor_id
        # Timestamp in milliseconds since the epoch.
        cmd_pb.timestamp = int(time.time() * 1000.0)
        cmd_pb.detail.Pack(detail_pb)
        return cmd_pb

    @staticmethod
    def _fetch_bool_result(return_id):
        """Block on ``return_id`` and return the decoded ``boolRes`` value."""
        result = remote_call_pb2.BoolResult()
        result.ParseFromString(ray.get(return_id))
        return result.boolRes

    @staticmethod
    def request_job_worker_rollback(master: ActorHandle,
                                    request: WorkerRollbackRequest):
        """Ask the master to roll back this worker.

        Args:
            master: handle of the job master actor.
            request: rollback request holding the sender's actor id and the
                exception message.

        Returns:
            The ``boolRes`` value of the master's reply.
        """
        # Local import keeps this block self-contained. socket.gethostname()
        # is used instead of os.uname()[1] because os.uname() is not
        # available on every platform (for example Windows).
        import socket

        logger.info("Remote call mst: request job worker rollback start.")
        rollback_request_pb = remote_call_pb2.WorkerRollbackRequest()
        rollback_request_pb.exception_msg = request.exception_msg()
        rollback_request_pb.worker_hostname = socket.gethostname()
        rollback_request_pb.worker_pid = str(os.getpid())
        request_pb = RemoteCallMst._build_worker_cmd(request.from_actor_id,
                                                     rollback_request_pb)
        return_ids = master.requestJobWorkerRollback\
            .remote(request_pb.SerializeToString())
        bool_res = RemoteCallMst._fetch_bool_result(return_ids)
        logger.info("Remote call mst: request job worker rollback finish.")
        return bool_res

    @staticmethod
    def report_job_worker_commit(master: ActorHandle,
                                 report: WorkerCommitReport):
        """Report a committed checkpoint to the master.

        Args:
            master: handle of the job master actor.
            report: commit report holding the sender's actor id and the
                committed checkpoint id.

        Returns:
            The ``boolRes`` value of the master's reply.
        """
        logger.info("Remote call mst: report job worker commit start.")
        wk_commit = remote_call_pb2.WorkerCommitReport()
        wk_commit.commit_checkpoint_id = report.commit_checkpoint_id
        report_pb = RemoteCallMst._build_worker_cmd(report.from_actor_id,
                                                    wk_commit)
        return_id = master.reportJobWorkerCommit\
            .remote(report_pb.SerializeToString())
        bool_res = RemoteCallMst._fetch_bool_result(return_id)
        logger.info("Remote call mst: report job worker commit finish.")
        return bool_res
+10 -10
View File
@@ -3,11 +3,11 @@ import pickle
import msgpack
from ray.streaming import message
RECORD_TYPE_ID = 0
KEY_RECORD_TYPE_ID = 1
CROSS_LANG_TYPE_ID = 0
JAVA_TYPE_ID = 1
PYTHON_TYPE_ID = 2
_RECORD_TYPE_ID = 0
_KEY_RECORD_TYPE_ID = 1
_CROSS_LANG_TYPE_ID = b"0"
_JAVA_TYPE_ID = b"1"
_PYTHON_TYPE_ID = b"2"
class Serializer(ABC):
@@ -33,21 +33,21 @@ class CrossLangSerializer(Serializer):
def serialize(self, obj):
if type(obj) is message.Record:
fields = [RECORD_TYPE_ID, obj.stream, obj.value]
fields = [_RECORD_TYPE_ID, obj.stream, obj.value]
elif type(obj) is message.KeyRecord:
fields = [KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value]
fields = [_KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value]
else:
raise Exception("Unsupported value {}".format(obj))
return msgpack.packb(fields, use_bin_type=True)
def deserialize(self, data):
fields = msgpack.unpackb(data, raw=False)
if fields[0] == RECORD_TYPE_ID:
fields = msgpack.unpackb(data, raw=False, strict_map_key=False)
if fields[0] == _RECORD_TYPE_ID:
stream, value = fields[1:]
record = message.Record(value)
record.stream = stream
return record
elif fields[0] == KEY_RECORD_TYPE_ID:
elif fields[0] == _KEY_RECORD_TYPE_ID:
stream, key, value = fields[1:]
key_record = message.KeyRecord(key, value)
key_record.stream = stream
+54 -263
View File
@@ -1,30 +1,14 @@
import logging
import pickle
import threading
import time
import typing
from abc import ABC, abstractmethod
from typing import Optional
from ray.streaming.collector import OutputCollector
from ray.streaming.config import Config
from ray.streaming.context import RuntimeContextImpl
from ray.streaming.generated import remote_call_pb2
from ray.streaming.runtime import serialization
from ray.streaming.runtime.command import WorkerCommitReport
from ray.streaming.runtime.failover import Barrier, OpCheckpointInfo
from ray.streaming.runtime.remote_call import RemoteCallMst
from ray.streaming.runtime.serialization import \
PythonSerializer, CrossLangSerializer
from ray.streaming.runtime.transfer import CheckpointBarrier
from ray.streaming.runtime.transfer import DataMessage
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader
from ray.streaming.runtime.transfer import ChannelRecoverInfo
from ray.streaming.runtime.transfer import ChannelInterruptException
if typing.TYPE_CHECKING:
from ray.streaming.runtime.worker import JobWorker
from ray.streaming.runtime.processor import Processor, SourceProcessor
logger = logging.getLogger(__name__)
@@ -32,85 +16,18 @@ logger = logging.getLogger(__name__)
class StreamTask(ABC):
"""Base class for all streaming tasks. Each task runs a processor."""
def __init__(self, task_id: int, processor: "Processor",
worker: "JobWorker", last_checkpoint_id: int):
self.worker_context = worker.worker_context
self.vertex_context = worker.execution_vertex_context
def __init__(self, task_id, processor, worker):
self.task_id = task_id
self.processor = processor
self.worker = worker
self.config: dict = worker.config
self.reader: Optional[DataReader] = None
self.writer: Optional[DataWriter] = None
self.is_initial_state = True
self.last_checkpoint_id: int = last_checkpoint_id
self.config = worker.config
self.reader = None # DataReader
self.writers = {} # ExecutionEdge -> DataWriter
self.thread = None
self.prepare_task()
self.thread = threading.Thread(target=self.run, daemon=True)
def do_checkpoint(self, checkpoint_id: int, input_points):
logger.info("Start do checkpoint, cp id {}, inputPoints {}.".format(
checkpoint_id, input_points))
output_points = None
if self.writer is not None:
output_points = self.writer.get_output_checkpoints()
operator_checkpoint = self.processor.save_checkpoint()
op_checkpoint_info = OpCheckpointInfo(
operator_checkpoint, input_points, output_points, checkpoint_id)
self.__save_cp_state_and_report(op_checkpoint_info, checkpoint_id)
barrier_pb = remote_call_pb2.Barrier()
barrier_pb.id = checkpoint_id
byte_buffer = barrier_pb.SerializeToString()
if self.writer is not None:
self.writer.broadcast_barrier(checkpoint_id, byte_buffer)
logger.info("Operator checkpoint {} finish.".format(checkpoint_id))
def __save_cp_state_and_report(self, op_checkpoint_info, checkpoint_id):
logger.info(
"Start to save cp state and report, checkpoint id is {}.".format(
checkpoint_id))
self.__save_cp(op_checkpoint_info, checkpoint_id)
self.__report_commit(checkpoint_id)
self.last_checkpoint_id = checkpoint_id
def __save_cp(self, op_checkpoint_info, checkpoint_id):
logger.info("save operator cp, op_checkpoint_info={}".format(
op_checkpoint_info))
cp_bytes = pickle.dumps(op_checkpoint_info)
self.worker.context_backend.put(
self.__gen_op_checkpoint_key(checkpoint_id), cp_bytes)
def __report_commit(self, checkpoint_id: int):
logger.info("Report commit, checkpoint id {}.".format(checkpoint_id))
report = WorkerCommitReport(self.vertex_context.actor_id.binary(),
checkpoint_id)
RemoteCallMst.report_job_worker_commit(self.worker.master_actor,
report)
def clear_expired_cp_state(self, checkpoint_id):
cp_key = self.__gen_op_checkpoint_key(checkpoint_id)
self.worker.context_backend.remove(cp_key)
def clear_expired_queue_msg(self, checkpoint_id):
# clear operator checkpoint
if self.writer is not None:
self.writer.clear_checkpoint(checkpoint_id)
def request_rollback(self, exception_msg: str):
self.worker.request_rollback(exception_msg)
def __gen_op_checkpoint_key(self, checkpoint_id):
op_checkpoint_key = Config.JOB_WORKER_OP_CHECKPOINT_PREFIX_KEY + str(
self.vertex_context.job_name) + "_" + str(
self.vertex_context.exe_vertex_name) + "_" + str(checkpoint_id)
logger.info(
"Generate op checkpoint key {}. ".format(op_checkpoint_key))
return op_checkpoint_key
def prepare_task(self, is_recreate: bool):
logger.info(
"Preparing stream task, is_recreate={}.".format(is_recreate))
def prepare_task(self):
channel_conf = dict(self.worker.config)
channel_size = int(
self.worker.config.get(Config.CHANNEL_SIZE,
@@ -122,76 +39,45 @@ class StreamTask(ABC):
execution_vertex_context = self.worker.execution_vertex_context
build_time = execution_vertex_context.build_time
# when use memory state, if actor throw exception, will miss state
op_checkpoint_info = OpCheckpointInfo()
cp_bytes = None
# get operator checkpoint
if is_recreate:
cp_key = self.__gen_op_checkpoint_key(self.last_checkpoint_id)
logger.info("Getting task checkpoints from state, "
"cpKey={}, checkpointId={}.".format(
cp_key, self.last_checkpoint_id))
cp_bytes = self.worker.context_backend.get(cp_key)
if cp_bytes is None:
msg = "Task recover failed, checkpoint is null!"\
"cpKey={}".format(cp_key)
raise RuntimeError(msg)
if cp_bytes is not None:
op_checkpoint_info = pickle.loads(cp_bytes)
self.processor.load_checkpoint(op_checkpoint_info.operator_point)
logger.info("Stream task recover from checkpoint state,"
"checkpoint bytes len={}, checkpointInfo={}.".format(
cp_bytes.__len__(), op_checkpoint_info))
# writers
collectors = []
output_actors_map = {}
for edge in execution_vertex_context.output_execution_edges:
target_task_id = edge.target_execution_vertex_id
target_actor = execution_vertex_context \
.get_target_actor_by_execution_vertex_id(target_task_id)
target_actor = execution_vertex_context\
.get_target_actor_by_vertex_id(target_task_id)
channel_name = ChannelID.gen_id(self.task_id, target_task_id,
build_time)
output_actors_map[channel_name] = target_actor
if len(output_actors_map) > 0:
channel_str_ids = list(output_actors_map.keys())
target_actors = list(output_actors_map.values())
logger.info("Create DataWriter channel_ids {},"
"target_actors {}, output_points={}.".format(
channel_str_ids, target_actors,
op_checkpoint_info.output_points))
self.writer = DataWriter(channel_str_ids, target_actors,
channel_conf)
logger.info("Create DataWriter succeed channel_ids {}, "
"target_actors {}.".format(channel_str_ids,
target_actors))
for edge in execution_vertex_context.output_execution_edges:
if len(output_actors_map) > 0:
channel_ids = list(output_actors_map.keys())
target_actors = list(output_actors_map.values())
logger.info(
"Create DataWriter channel_ids {}, target_actors {}."
.format(channel_ids, target_actors))
writer = DataWriter(channel_ids, target_actors, channel_conf)
self.writers[edge] = writer
collectors.append(
OutputCollector(self.writer, channel_str_ids,
target_actors, edge.partition))
OutputCollector(writer, channel_ids, target_actors,
edge.partition))
# readers
input_actor_map = {}
for edge in execution_vertex_context.input_execution_edges:
source_task_id = edge.source_execution_vertex_id
source_actor = execution_vertex_context \
.get_source_actor_by_execution_vertex_id(source_task_id)
source_actor = execution_vertex_context\
.get_source_actor_by_vertex_id(source_task_id)
channel_name = ChannelID.gen_id(source_task_id, self.task_id,
build_time)
input_actor_map[channel_name] = source_actor
if len(input_actor_map) > 0:
channel_str_ids = list(input_actor_map.keys())
channel_ids = list(input_actor_map.keys())
from_actors = list(input_actor_map.values())
logger.info("Create DataReader, channels {},"
"input_actors {}, input_points={}.".format(
channel_str_ids, from_actors,
op_checkpoint_info.input_points))
self.reader = DataReader(channel_str_ids, from_actors,
channel_conf)
logger.info("Create DataReader, channels {}, input_actors {}."
.format(channel_ids, from_actors))
self.reader = DataReader(channel_ids, from_actors, channel_conf)
def exit_handler():
# Make DataReader stop read data when MockQueue destructor
@@ -201,31 +87,21 @@ class StreamTask(ABC):
import atexit
atexit.register(exit_handler)
# TODO(chaokunyang) add task/job config
runtime_context = RuntimeContextImpl(
self.worker.task_id,
execution_vertex_context.execution_vertex.execution_vertex_index,
execution_vertex_context.get_parallelism(),
config=channel_conf,
job_config=channel_conf)
execution_vertex_context.get_parallelism())
logger.info("open Processor {}".format(self.processor))
self.processor.open(collectors, runtime_context)
# immediately save cp. In case of FO in cp 0
# or use old cp in multi node FO.
self.__save_cp(op_checkpoint_info, self.last_checkpoint_id)
def recover(self, is_recreate: bool):
self.prepare_task(is_recreate)
recover_info = ChannelRecoverInfo()
if self.reader is not None:
recover_info = self.reader.get_channel_recover_info()
@abstractmethod
def init(self):
pass
def start(self):
self.thread.start()
logger.info("Start operator success.")
return recover_info
@abstractmethod
def run(self):
pass
@@ -234,24 +110,14 @@ class StreamTask(ABC):
def cancel_task(self):
pass
@abstractmethod
def commit_trigger(self, barrier: Barrier) -> bool:
pass
class InputStreamTask(StreamTask):
"""Base class for stream tasks that execute a
:class:`runtime.processor.OneInputProcessor` or
:class:`runtime.processor.TwoInputProcessor` """
def commit_trigger(self, barrier):
raise RuntimeError(
"commit_trigger is only supported in SourceStreamTask.")
def __init__(self, task_id, processor_instance, worker,
last_checkpoint_id):
super().__init__(task_id, processor_instance, worker,
last_checkpoint_id)
def __init__(self, task_id, processor_instance, worker):
super().__init__(task_id, processor_instance, worker)
self.running = True
self.stopped = False
self.read_timeout_millis = \
@@ -260,58 +126,25 @@ class InputStreamTask(StreamTask):
self.python_serializer = PythonSerializer()
self.cross_lang_serializer = CrossLangSerializer()
def init(self):
pass
def run(self):
logger.info("Input task thread start.")
try:
while self.running:
self.worker.initial_state_lock.acquire()
try:
item = self.reader.read(self.read_timeout_millis)
self.is_initial_state = False
finally:
self.worker.initial_state_lock.release()
if item is None:
continue
if isinstance(item, DataMessage):
msg_data = item.body
type_id = msg_data[0]
if type_id == serialization.PYTHON_TYPE_ID:
msg = self.python_serializer.deserialize(msg_data[1:])
else:
msg = self.cross_lang_serializer.deserialize(
msg_data[1:])
self.processor.process(msg)
elif isinstance(item, CheckpointBarrier):
logger.info("Got barrier:{}".format(item))
logger.info("Start to do checkpoint {}.".format(
item.checkpoint_id))
input_points = item.get_input_checkpoints()
self.do_checkpoint(item.checkpoint_id, input_points)
logger.info("Do checkpoint {} success.".format(
item.checkpoint_id))
while self.running:
item = self.reader.read(self.read_timeout_millis)
if item is not None:
msg_data = item.body()
type_id = msg_data[:1]
if (type_id == serialization._PYTHON_TYPE_ID):
msg = self.python_serializer.deserialize(msg_data[1:])
else:
raise RuntimeError(
"Unknown item type! item={}".format(item))
except ChannelInterruptException:
logger.info("queue has stopped.")
except BaseException as e:
logger.exception(
"Last success checkpointId={}, now occur error.".format(
self.last_checkpoint_id))
self.request_rollback(str(e))
logger.info("Source fetcher thread exit.")
msg = self.cross_lang_serializer.deserialize(msg_data[1:])
self.processor.process(msg)
self.stopped = True
def cancel_task(self):
self.running = False
while not self.stopped:
time.sleep(0.5)
pass
@@ -319,64 +152,22 @@ class OneInputStreamTask(InputStreamTask):
"""A stream task for executing :class:`runtime.processor.OneInputProcessor`
"""
def __init__(self, task_id, processor_instance, worker,
last_checkpoint_id):
super().__init__(task_id, processor_instance, worker,
last_checkpoint_id)
def __init__(self, task_id, processor_instance, worker):
super().__init__(task_id, processor_instance, worker)
class SourceStreamTask(StreamTask):
"""A stream task for executing :class:`runtime.processor.SourceProcessor`
"""
processor: "SourceProcessor"
def __init__(self, task_id: int, processor_instance: "SourceProcessor",
worker: "JobWorker", last_checkpoint_id):
super().__init__(task_id, processor_instance, worker,
last_checkpoint_id)
self.running = True
self.stopped = False
self.__pending_barrier: Optional[Barrier] = None
def __init__(self, task_id, processor_instance, worker):
super().__init__(task_id, processor_instance, worker)
def init(self):
pass
def run(self):
logger.info("Source task thread start.")
try:
while self.running:
self.processor.fetch()
# check checkpoint
if self.__pending_barrier is not None:
# source fetcher only have outputPoints
barrier = self.__pending_barrier
logger.info("Start to do checkpoint {}.".format(
barrier.id))
self.do_checkpoint(barrier.id, barrier)
logger.info("Finish to do checkpoint {}.".format(
barrier.id))
self.__pending_barrier = None
except ChannelInterruptException:
logger.info("queue has stopped.")
except Exception as e:
logger.exception(
"Last success checkpointId={}, now occur error.".format(
self.last_checkpoint_id))
self.request_rollback(str(e))
logger.info("Source fetcher thread exit.")
self.stopped = True
def commit_trigger(self, barrier):
if self.__pending_barrier is not None:
logger.warning(
"Last barrier is not broadcast now, skip this barrier trigger."
)
return False
self.__pending_barrier = barrier
return True
self.processor.run()
def cancel_task(self):
self.running = False
while not self.stopped:
time.sleep(0.5)
pass
pass
+24 -156
View File
@@ -2,8 +2,6 @@ import logging
import random
from queue import Queue
from typing import List
from enum import Enum
from abc import ABC, abstractmethod
import ray
import ray.streaming._streaming as _streaming
@@ -15,7 +13,6 @@ from ray._raylet import PythonFunctionDescriptor
from ray._raylet import Language
CHANNEL_ID_LEN = 20
logger = logging.getLogger(__name__)
class ChannelID:
@@ -100,70 +97,40 @@ def channel_bytes_to_str(id_bytes):
return bytes.hex(id_bytes)
class Message(ABC):
@property
@abstractmethod
def body(self):
"""Message data"""
pass
@property
@abstractmethod
def timestamp(self):
"""Get timestamp when item is written by upstream DataWriter
"""
pass
@property
@abstractmethod
def channel_id(self):
"""Get string id of channel where data is coming from"""
pass
@property
@abstractmethod
def message_id(self):
"""Get message id of the message"""
pass
class DataMessage(Message):
class DataMessage:
"""
DataMessage represents data between upstream and downstream operator.
DataMessage represents data between upstream and downstream operator
"""
def __init__(self,
body,
timestamp,
message_id,
channel_id,
message_id_,
is_empty_message=False):
self.__body = body
self.__timestamp = timestamp
self.__channel_id = channel_id
self.__message_id = message_id
self.__message_id = message_id_
self.__is_empty_message = is_empty_message
def __len__(self):
return len(self.__body)
@property
def body(self):
"""Message data"""
return self.__body
@property
def timestamp(self):
"""Get timestamp when item is written by upstream DataWriter
"""
return self.__timestamp
@property
def channel_id(self):
"""Get string id of channel where data is coming from
"""
return self.__channel_id
@property
def message_id(self):
return self.__message_id
@property
def is_empty_message(self):
"""Whether this message is an empty message.
Upstream DataWriter will send an empty message when this is no data
@@ -171,47 +138,10 @@ class DataMessage(Message):
"""
return self.__is_empty_message
class CheckpointBarrier(Message):
"""
CheckpointBarrier separates the records in the data stream into the set of
records that goes into the current snapshot, and the records that go into
the next snapshot. Each barrier carries the ID of the snapshot whose
records it pushed in front of it.
"""
def __init__(self, barrier_data, timestamp, message_id, channel_id,
offsets, barrier_id, barrier_type):
self.__barrier_data = barrier_data
self.__timestamp = timestamp
self.__message_id = message_id
self.__channel_id = channel_id
self.checkpoint_id = barrier_id
self.offsets = offsets
self.barrier_type = barrier_type
@property
def body(self):
return self.__barrier_data
@property
def timestamp(self):
return self.__timestamp
@property
def channel_id(self):
return self.__channel_id
@property
def message_id(self):
return self.__message_id
def get_input_checkpoints(self):
return self.offsets
def __str__(self):
return "Barrier(Checkpoint id : {})".format(self.checkpoint_id)
class ChannelCreationParametersBuilder:
"""
@@ -288,6 +218,9 @@ class ChannelCreationParametersBuilder:
_python_reader_sync_function_descriptor = sync_function
logger = logging.getLogger(__name__)
class DataWriter:
"""Data Writer is a wrapper of streaming c++ DataWriter, which sends data
to downstream workers
@@ -331,26 +264,6 @@ class DataWriter:
msg_id = self.writer.write(channel_id.object_qid, item)
return msg_id
def broadcast_barrier(self, checkpoint_id: int, body: bytes):
    """Send a barrier to every downstream channel.

    Delegates directly to the underlying writer's `broadcast_barrier`.

    Args:
        checkpoint_id: the checkpoint_id
        body: barrier payload
    """
    native_writer = self.writer
    native_writer.broadcast_barrier(checkpoint_id, body)
def get_output_checkpoints(self) -> List[int]:
    """Fetch the output offsets of all downstream channels.

    Returns:
        a list containing the current msg_id of each downstream channel,
        as reported by the underlying writer
    """
    offsets = self.writer.get_output_checkpoints()
    return offsets
def clear_checkpoint(self, checkpoint_id):
    """Log the request, then ask the underlying writer to clear a checkpoint.

    Args:
        checkpoint_id: id forwarded unchanged to the writer
    """
    notice = "producer start to clear checkpoint, checkpoint_id={}".format(
        checkpoint_id)
    logger.info(notice)
    self.writer.clear_checkpoint(checkpoint_id)
def stop(self):
logger.info("stopping channel writer.")
self.writer.stop()
@@ -381,20 +294,18 @@ class DataReader:
]
creation_parameters = ChannelCreationParametersBuilder()
creation_parameters.build_input_queue_parameters(from_actors)
py_seq_ids = [0 for _ in range(len(input_channels))]
py_msg_ids = [0 for _ in range(len(input_channels))]
timer_interval = int(conf.get(Config.TIMER_INTERVAL_MS, -1))
is_recreate = bool(conf.get(Config.IS_RECREATE, False))
config_bytes = _to_native_conf(conf)
self.__queue = Queue(10000)
is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL
self.reader, queues_creation_status = _streaming.DataReader.create(
self.reader = _streaming.DataReader.create(
py_input_channels, creation_parameters.get_parameters(),
py_msg_ids, timer_interval, config_bytes, is_mock)
self.__creation_status = {}
for q, status in queues_creation_status.items():
self.__creation_status[q] = ChannelCreationStatus(status)
logger.info("create DataReader succeed, creation_status={}".format(
self.__creation_status))
py_seq_ids, py_msg_ids, timer_interval, is_recreate, config_bytes,
is_mock)
logger.info("create DataReader succeed")
def read(self, timeout_millis):
"""Read data from channel
@@ -405,17 +316,16 @@ class DataReader:
channel item
"""
if self.__queue.empty():
messages = self.reader.read(timeout_millis)
for message in messages:
self.__queue.put(message)
msgs = self.reader.read(timeout_millis)
for msg in msgs:
msg_bytes, msg_id, timestamp, qid_bytes = msg
data_msg = DataMessage(msg_bytes, timestamp,
channel_bytes_to_str(qid_bytes), msg_id)
self.__queue.put(data_msg)
if self.__queue.empty():
return None
return self.__queue.get()
def get_channel_recover_info(self):
    """Wrap this reader's channel creation statuses in a ChannelRecoverInfo."""
    creation_status = self.__creation_status
    return ChannelRecoverInfo(creation_status)
def stop(self):
logger.info("stopping Data Reader.")
self.reader.stop()
@@ -462,45 +372,3 @@ class ChannelInitException(Exception):
class ChannelInterruptException(Exception):
    """Exception that keeps an optional message on its `msg` attribute.

    Args:
        msg: optional description. It is always stored on `self.msg`. When
            it is not None it is also forwarded to `Exception`, so that
            `str(exc)`, `exc.args` and pickling carry the message even when
            the caller passes it by keyword.
    """

    def __init__(self, msg=None):
        # A positional argument already reaches `args` through
        # BaseException.__new__; a keyword argument does not, so forward it
        # explicitly. Skip the call for None to leave `args` untouched.
        if msg is not None:
            super().__init__(msg)
        self.msg = msg
class ChannelRecoverInfo:
    """Holds a mapping of queue -> creation status and answers queries on it.

    Args:
        queue_creation_status_map: mapping of queue to its creation status;
            an empty dict is used when None is given.
    """

    def __init__(self, queue_creation_status_map=None):
        self.__queue_creation_status_map = (
            {} if queue_creation_status_map is None
            else queue_creation_status_map)

    def get_creation_status(self):
        """Return the stored queue -> creation status mapping."""
        return self.__queue_creation_status_map

    def get_data_lost_queues(self):
        """Return the set of queues whose status is `DataLost`."""
        return {
            queue
            for queue, status in self.__queue_creation_status_map.items()
            if status == ChannelCreationStatus.DataLost
        }

    def __str__(self):
        lost_queues = self.get_data_lost_queues()
        return "QueueRecoverInfo [dataLostQueues=%s]" % (lost_queues)
class ChannelCreationStatus(Enum):
    """Status codes describing how a channel was created.

    `DataReader` builds these from the raw status values returned by the
    native reader's `create` call.
    """

    FreshStarted = 0
    PullOk = 1
    Timeout = 2
    # Queues with this status are the ones reported by
    # ChannelRecoverInfo.get_data_lost_queues().
    DataLost = 3
def channel_id_bytes_to_str(id_bytes):
    """Convert a channel id to its string representation.

    Args:
        id_bytes: bytes representation of channel id; a `str` is accepted
            too and returned unchanged

    Returns:
        string representation of channel id (lowercase hex of the bytes)

    Raises:
        TypeError: if `id_bytes` is neither `str` nor `bytes`
    """
    if isinstance(id_bytes, str):
        return id_bytes
    # Raise explicitly instead of asserting: asserts are stripped under -O,
    # which would let invalid input through to the hex conversion.
    if not isinstance(id_bytes, bytes):
        raise TypeError(
            "channel id must be str or bytes, got {}".format(
                type(id_bytes).__name__))
    return id_bytes.hex()
+43 -323
View File
@@ -1,23 +1,12 @@
import enum
import logging.config
import os
import threading
import time
from typing import Optional
import logging
import ray
import ray.streaming.runtime.processor as processor
from ray.actor import ActorHandle
from ray.streaming.generated import remote_call_pb2
from ray.streaming.runtime.command import WorkerRollbackRequest
from ray.streaming.runtime.failover import Barrier
from ray.streaming.runtime.graph import ExecutionVertexContext, ExecutionVertex
from ray.streaming.runtime.remote_call import CallResult, RemoteCallMst
from ray.streaming.runtime.context_backend import ContextBackendFactory
from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask
from ray.streaming.runtime.transfer import channel_bytes_to_str
from ray.streaming.config import Config
import ray.streaming._streaming as _streaming
import ray.streaming.generated.remote_call_pb2 as remote_call_pb
import ray.streaming.runtime.processor as processor
from ray.streaming.config import Config
from ray.streaming.runtime.graph import ExecutionVertexContext
from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask
logger = logging.getLogger(__name__)
@@ -29,179 +18,74 @@ _NOT_READY_FLAG_ = b" " * 4
class JobWorker(object):
"""A streaming job worker is used to execute user-defined function and
interact with `JobMaster`"""
master_actor: Optional[ActorHandle]
worker_context: Optional[remote_call_pb2.PythonJobWorkerContext]
execution_vertex_context: Optional[ExecutionVertexContext]
__need_rollback: bool
def __init__(self, execution_vertex_pb_bytes):
logger.info("Creating job worker, pid={}".format(os.getpid()))
execution_vertex_pb = remote_call_pb2\
.ExecutionVertexContext.ExecutionVertex()
execution_vertex_pb.ParseFromString(execution_vertex_pb_bytes)
self.execution_vertex = ExecutionVertex(execution_vertex_pb)
self.config = self.execution_vertex.config
def __init__(self):
self.worker_context = None
self.execution_vertex_context = None
self.config = None
self.task_id = None
self.task = None
self.stream_processor = None
self.master_actor = None
self.context_backend = ContextBackendFactory.get_context_backend(
self.config)
self.initial_state_lock = threading.Lock()
self.__rollback_cnt: int = 0
self.__is_recreate: bool = False
self.__state = WorkerState()
self.__need_rollback = True
self.reader_client = None
self.writer_client = None
try:
# load checkpoint
was_reconstructed = ray.get_runtime_context(
).was_current_actor_reconstructed
logger.info(
"Worker was reconstructed: {}".format(was_reconstructed))
if was_reconstructed:
job_worker_context_key = self.__get_job_worker_context_key()
logger.info("Worker get checkpoint state by key: {}.".format(
job_worker_context_key))
context_bytes = self.context_backend.get(
job_worker_context_key)
if context_bytes is not None and context_bytes.__len__() > 0:
self.init(context_bytes)
self.request_rollback(
"Python worker recover from checkpoint.")
else:
logger.error(
"Error! Worker get checkpoint state by key {}"
" returns None, please check your state backend"
", only reliable state backend supports fail-over."
.format(job_worker_context_key))
except Exception:
logger.exception("Error in __init__ of JobWorker")
logger.info("Creating job worker succeeded. worker config {}".format(
self.config))
logger.info("Creating job worker succeeded.")
def init(self, worker_context_bytes):
logger.info("Start to init job worker")
try:
# deserialize context
worker_context = remote_call_pb2.PythonJobWorkerContext()
worker_context.ParseFromString(worker_context_bytes)
self.worker_context = worker_context
self.master_actor = ActorHandle._deserialization_helper(
worker_context.master_actor)
worker_context = remote_call_pb.PythonJobWorkerContext()
worker_context.ParseFromString(worker_context_bytes)
self.worker_context = worker_context
# build vertex context from pb
self.execution_vertex_context = ExecutionVertexContext(
worker_context.execution_vertex_context)
self.execution_vertex = self\
.execution_vertex_context.execution_vertex
# build vertex context from pb
self.execution_vertex_context = ExecutionVertexContext(
worker_context.execution_vertex_context)
# save context
job_worker_context_key = self.__get_job_worker_context_key()
self.context_backend.put(job_worker_context_key,
worker_context_bytes)
# use vertex id as task id
self.task_id = self.execution_vertex_context.get_task_id()
# use vertex id as task id
self.task_id = self.execution_vertex_context.get_task_id()
# build and get processor from operator
operator = self.execution_vertex_context.stream_operator
self.stream_processor = processor.build_processor(operator)
logger.info("Initializing job worker, exe_vertex_name={},"
"task_id: {}, operator: {}, pid={}".format(
self.execution_vertex_context.exe_vertex_name,
self.task_id, self.stream_processor, os.getpid()))
# build and get processor from operator
operator = self.execution_vertex_context.stream_operator
self.stream_processor = processor.build_processor(operator)
logger.info(
"Initializing job worker, task_id: {}, operator: {}.".format(
self.task_id, self.stream_processor))
# get config from vertex
self.config = self.execution_vertex_context.config
# get config from vertex
self.config = self.execution_vertex_context.config
if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL):
self.reader_client = _streaming.ReaderClient()
self.writer_client = _streaming.WriterClient()
if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL):
self.reader_client = _streaming.ReaderClient()
self.writer_client = _streaming.WriterClient()
logger.info("Job worker init succeeded.")
except Exception:
logger.exception("Error when init job worker.")
return False
self.task = self.create_stream_task()
logger.info("Job worker init succeeded.")
return True
def create_stream_task(self, checkpoint_id):
def start(self):
    """Start the worker's stream task and log that it succeeded."""
    stream_task = self.task
    stream_task.start()
    logger.info("Job worker start succeeded.")
def create_stream_task(self):
if isinstance(self.stream_processor, processor.SourceProcessor):
return SourceStreamTask(self.task_id, self.stream_processor, self,
checkpoint_id)
return SourceStreamTask(self.task_id, self.stream_processor, self)
elif isinstance(self.stream_processor, processor.OneInputProcessor):
return OneInputStreamTask(self.task_id, self.stream_processor,
self, checkpoint_id)
self)
else:
raise Exception("Unsupported processor type: " +
str(type(self.stream_processor)))
type(self.stream_processor))
def rollback(self, checkpoint_id_bytes):
    """Roll the worker back to a checkpoint by recreating its stream task.

    Args:
        checkpoint_id_bytes: serialized `remote_call_pb2.CheckpointId`

    Returns:
        serialized call result: "skipped" when the running task is already
        in its initial state for this checkpoint, "success" (carrying the
        queue recover info) after the task was recreated and recovered,
        "fail" when recreating or recovering raised.
    """
    id_pb = remote_call_pb2.CheckpointId()
    id_pb.ParseFromString(checkpoint_id_bytes)
    checkpoint_id = id_pb.checkpoint_id
    logger.info("Start rollback, checkpoint_id={}".format(checkpoint_id))

    # Every rollback after the first one counts as a recreate.
    self.__rollback_cnt += 1
    if self.__rollback_cnt > 1:
        self.__is_recreate = True

    # skip useless rollback
    with self.initial_state_lock:
        already_initial = (
            self.task is not None
            and self.task.thread.is_alive()
            and checkpoint_id == self.task.last_checkpoint_id
            and self.task.is_initial_state)
        if already_initial:
            logger.info(
                "Task is already in initial state, skip this rollback.")
            skipped = CallResult.skipped(
                "Task is already in initial state, skip this rollback.")
            return self.__gen_call_result(skipped)

    # restart task
    try:
        if self.task is not None:
            # make sure the runner is closed
            self.task.cancel_task()
            del self.task

        self.task = self.create_stream_task(checkpoint_id)

        recover_info = self.task.recover(self.__is_recreate)

        self.__state.set_type(StateType.RUNNING)
        self.__need_rollback = False

        logger.info(
            "Rollback success, checkpoint is {}, qRecoverInfo is {}.".
            format(checkpoint_id, recover_info))

        return self.__gen_call_result(CallResult.success(recover_info))
    except Exception:
        logger.exception("Rollback has exception.")
        return self.__gen_call_result(CallResult.fail())
def on_reader_message(self, *buffers):
def on_reader_message(self, buffer: bytes):
"""Called by upstream queue writer to send data message to downstream
queue reader.
"""
if self.reader_client is None:
logger.info("reader_client is None, skip writer transfer")
return
self.reader_client.on_reader_message(*buffers)
self.reader_client.on_reader_message(buffer)
def on_reader_message_sync(self, buffer: bytes):
"""Called by upstream queue writer to send
control message to downstream downstream queue reader.
"""Called by upstream queue writer to send control message to downstream
downstream queue reader.
"""
if self.reader_client is None:
logger.info("task is None, skip reader transfer")
return _NOT_READY_FLAG_
result = self.reader_client.on_reader_message_sync(buffer)
return result.to_pybytes()
@@ -210,9 +94,6 @@ class JobWorker(object):
"""Called by downstream queue reader to send notify message to
upstream queue writer.
"""
if self.writer_client is None:
logger.info("writer_client is None, skip writer transfer")
return
self.writer_client.on_writer_message(buffer)
def on_writer_message_sync(self, buffer: bytes):
@@ -223,164 +104,3 @@ class JobWorker(object):
return _NOT_READY_FLAG_
result = self.writer_client.on_writer_message_sync(buffer)
return result.to_pybytes()
def shutdown_without_reconstruction(self):
    """Log the shutdown, then exit the actor via `ray.actor.exit_actor`."""
    farewell = "Python worker shutdown without reconstruction."
    logger.info(farewell)
    ray.actor.exit_actor()
def notify_checkpoint_timeout(self, checkpoint_id_bytes):
    """Accept a checkpoint-timeout notification; currently does nothing.

    Args:
        checkpoint_id_bytes: ignored
    """
    return None
def commit(self, barrier_bytes):
    """Forward a commit trigger to the running task, if there is one.

    Args:
        barrier_bytes: serialized `remote_call_pb2.Barrier`

    Returns:
        serialized `remote_call_pb2.BoolResult` whose `boolRes` is always
        True
    """
    parsed = remote_call_pb2.Barrier()
    parsed.ParseFromString(barrier_bytes)
    barrier = Barrier(parsed.id)
    logger.info("Receive trigger, barrier is {}.".format(barrier))

    current_task = self.task
    if current_task is not None:
        current_task.commit_trigger(barrier)

    reply = remote_call_pb2.BoolResult()
    reply.boolRes = True
    return reply.SerializeToString()
def clear_expired_cp(self, state_checkpoint_id_bytes,
                     queue_checkpoint_id_bytes):
    """Clear expired checkpoint state and expired queue messages.

    State is only cleared when the parsed state checkpoint id is positive;
    queue messages are always attempted.

    Args:
        state_checkpoint_id_bytes: serialized checkpoint id for state
        queue_checkpoint_id_bytes: serialized checkpoint id for queues

    Returns:
        serialized `remote_call_pb2.BoolResult`; `boolRes` is True only if
        both steps reported success
    """
    state_checkpoint_id = self.__parse_to_checkpoint_id(
        state_checkpoint_id_bytes)
    queue_checkpoint_id = self.__parse_to_checkpoint_id(
        queue_checkpoint_id_bytes)
    vertex_name = self.execution_vertex_context.exe_vertex_name
    logger.info("Start to clear expired checkpoint, checkpoint_id={},"
                "queue_checkpoint_id={}, exe_vertex_name={}.".format(
                    state_checkpoint_id, queue_checkpoint_id, vertex_name))

    ret = remote_call_pb2.BoolResult()
    if state_checkpoint_id > 0:
        ret.boolRes = self.__clear_expired_cp_state(state_checkpoint_id)
    else:
        ret.boolRes = True
    # `&=` (not `and`) so the queue cleanup runs even if state cleanup failed.
    ret.boolRes &= self.__clear_expired_queue_msg(queue_checkpoint_id)

    logger.info(
        "Clear expired checkpoint done, result={}, checkpoint_id={},"
        "queue_checkpoint_id={}, exe_vertex_name={}.".format(
            ret.boolRes, state_checkpoint_id, queue_checkpoint_id,
            self.execution_vertex_context.exe_vertex_name))
    return ret.SerializeToString()
def __clear_expired_cp_state(self, checkpoint_id):
    """Ask the task to drop expired checkpoint state.

    Returns:
        False (and does nothing) while the worker still needs a rollback,
        True otherwise, even when there is no task to clear.
    """
    if not self.__need_rollback:
        logger.info("Clear expired checkpoint state, cp id is {}.".format(
            checkpoint_id))
        current_task = self.task
        if current_task is not None:
            current_task.clear_expired_cp_state(checkpoint_id)
        return True

    logger.warning("Need rollback, skip clear_expired_cp_state"
                   ", checkpoint id: {}".format(checkpoint_id))
    return False
def __clear_expired_queue_msg(self, checkpoint_id):
    """Ask the task to drop expired queue messages.

    Returns:
        False (and does nothing) while the worker still needs a rollback,
        True otherwise, even when there is no task to clear.
    """
    if not self.__need_rollback:
        logger.info("Clear expired queue msg, checkpoint_id is {}.".format(
            checkpoint_id))
        current_task = self.task
        if current_task is not None:
            current_task.clear_expired_queue_msg(checkpoint_id)
        return True

    logger.warning("Need rollback, skip clear_expired_queue_msg"
                   ", checkpoint id: {}".format(checkpoint_id))
    return False
def __parse_to_checkpoint_id(self, checkpoint_id_bytes):
    """Deserialize a `remote_call_pb2.CheckpointId` and return its id."""
    parsed = remote_call_pb2.CheckpointId()
    parsed.ParseFromString(checkpoint_id_bytes)
    checkpoint_id = parsed.checkpoint_id
    return checkpoint_id
def check_if_need_rollback(self):
    """Report the need-rollback flag as a serialized `BoolResult`."""
    reply = remote_call_pb2.BoolResult()
    reply.boolRes = self.__need_rollback
    serialized = reply.SerializeToString()
    return serialized
def request_rollback(self, exception_msg="Python exception."):
    """Ask the job master to roll this worker back, with retries.

    Marks the worker as needing rollback, then sends the request up to
    `Config.REQUEST_ROLLBACK_RETRY_TIMES` times, sleeping one second after
    each unsuccessful attempt. If no attempt succeeds, the worker shuts
    down without reconstruction. Finally the worker state is set to
    `WAIT_ROLLBACK`.

    Args:
        exception_msg: text embedded in the rollback request message
    """
    logger.info("Request rollback.")
    self.__need_rollback = True
    self.__is_recreate = True

    request_ret = False
    for attempt in range(Config.REQUEST_ROLLBACK_RETRY_TIMES):
        logger.info("request rollback {} time".format(attempt))
        try:
            request = WorkerRollbackRequest(
                self.execution_vertex_context.actor_id.binary(),
                "Exception msg=%s, retry time=%d." % (exception_msg,
                                                      attempt))
            request_ret = RemoteCallMst.request_job_worker_rollback(
                self.master_actor, request)
        except Exception:
            # Keep the previous result and fall through to the retry path.
            logger.exception("Unexpected error when rollback")
        logger.info("request rollback {} time, ret={}".format(
            attempt, request_ret))
        if request_ret:
            break
        logger.warning(
            "Request rollback return false"
            ", maybe it's invalid request, try to sleep 1s.")
        time.sleep(1)

    if not request_ret:
        logger.warning("Request failed after retry {} times,"
                       "now worker shutdown without reconstruction."
                       .format(Config.REQUEST_ROLLBACK_RETRY_TIMES))
        self.shutdown_without_reconstruction()

    self.__state.set_type(StateType.WAIT_ROLLBACK)
def __gen_call_result(self, call_result):
    """Serialize a call result into `remote_call_pb2.CallResult` bytes.

    Copies `success` and `result_code`, the message when present, and,
    when a result object is present, each queue's creation status keyed
    by the channel's string form.
    """
    pb = remote_call_pb2.CallResult()
    pb.success = call_result.success
    pb.result_code = call_result.result_code.value

    message = call_result.result_msg
    if message is not None:
        pb.result_msg = message

    recover_info = call_result.result_obj
    if recover_info is not None:
        for queue, status in recover_info.get_creation_status().items():
            key = channel_bytes_to_str(queue)
            pb.result_obj.creation_status[key] = status.value

    return pb.SerializeToString()
def _gen_unique_key(self, key_prefix):
    """Build `<key_prefix><job name>_<execution vertex id>`.

    The job name is read from the worker config under
    `Config.STREAMING_JOB_NAME`.
    """
    job_name = str(self.config.get(Config.STREAMING_JOB_NAME))
    vertex_id = str(self.execution_vertex.execution_vertex_id)
    return key_prefix + job_name + "_" + vertex_id
def __get_job_worker_context_key(self) -> str:
    """Return the unique key under which this worker's context is stored."""
    prefix = Config.JOB_WORKER_CONTEXT_KEY
    return self._gen_unique_key(prefix)
class WorkerState:
    """Mutable holder for a worker's current state type.

    A new instance starts at `StateType.INIT`.
    """

    def __init__(self):
        self.__type = StateType.INIT

    def set_type(self, type):
        """Replace the stored state type with `type`."""
        self.__type = type

    def get_type(self):
        """Return the stored state type."""
        current = self.__type
        return current
class StateType(enum.Enum):
    """State types a worker can be in."""

    # Initial value of a freshly constructed WorkerState.
    INIT = 1
    # Set by JobWorker.rollback once the task has been recovered.
    RUNNING = 2
    # Set at the end of JobWorker.request_rollback.
    WAIT_ROLLBACK = 3