mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 20:36:30 +08:00
This reverts commit 1b1466748f.
This commit is contained in:
@@ -1,30 +0,0 @@
|
||||
class BaseWorkerCmd:
    """Common base for worker commands.

    Records the id of the actor the command comes from as ``from_actor_id``.
    """

    def __init__(self, actor_id):
        # Stored under the attribute name that readers of the command use.
        self.from_actor_id = actor_id
|
||||
|
||||
|
||||
class WorkerCommitReport(BaseWorkerCmd):
    """Worker command that carries the id of the checkpoint being committed."""

    def __init__(self, actor_id, commit_checkpoint_id):
        # The base class records the sending actor's id.
        super().__init__(actor_id)
        self.commit_checkpoint_id = commit_checkpoint_id
|
||||
|
||||
|
||||
class WorkerRollbackRequest(BaseWorkerCmd):
    """Worker command asking for a rollback, carrying an exception message."""

    def __init__(self, actor_id, exception_msg):
        # The base class records the sending actor's id.
        super().__init__(actor_id)
        # Name-mangled attribute; read it through exception_msg().
        self.__exception_msg = exception_msg

    def exception_msg(self):
        """Return the exception message given at construction."""
        return self.__exception_msg
|
||||
@@ -1,117 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from os import path
|
||||
|
||||
from ray.streaming.config import ConfigHelper, Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContextBackend(ABC):
    """Abstract key/value store; subclasses must implement get, put and remove."""

    @abstractmethod
    def get(self, key):
        """Look up the value stored under ``key``."""

    @abstractmethod
    def put(self, key, value):
        """Store ``value`` under ``key``."""

    @abstractmethod
    def remove(self, key):
        """Delete whatever is stored under ``key``."""
|
||||
|
||||
|
||||
class MemoryContextBackend(ContextBackend):
    """Context backend that keeps all entries in a dict held by the instance."""

    def __init__(self, conf):
        # ``conf`` is accepted but not read by this backend.
        self.__dic = {}

    def get(self, key):
        """Return the value stored under ``key``, or None when absent."""
        return self.__dic.get(key)

    def put(self, key, value):
        """Store ``value`` under ``key``, replacing any previous value."""
        self.__dic[key] = value

    def remove(self, key):
        """Drop ``key`` if it is present; a missing key is ignored."""
        self.__dic.pop(key, None)
|
||||
|
||||
|
||||
class LocalFileContextBackend(ContextBackend):
    """Context backend that stores each key as one file under a root directory.

    The root directory is obtained from
    ``ConfigHelper.get_cp_local_file_root_dir(conf)`` and created if it does
    not exist yet. Values are written and read back as raw bytes.
    """

    def __init__(self, conf):
        self.__dir = ConfigHelper.get_cp_local_file_root_dir(conf)
        logger.info("Start init local file state backend, root_dir={}.".format(
            self.__dir))
        try:
            os.mkdir(self.__dir)
        except FileExistsError:
            logger.info("dir already exists, skipped.")

    def put(self, key, value):
        """Write the bytes ``value`` to the file for ``key``, overwriting it."""
        logger.info("Put value of key {} start.".format(key))
        with open(self.__gen_file_path(key), "wb") as f:
            f.write(value)

    def get(self, key):
        """Return the bytes stored for ``key``, or None if there is no file."""
        logger.info("Get value of key {} start.".format(key))
        full_path = self.__gen_file_path(key)
        if not os.path.isfile(full_path):
            return None
        with open(full_path, "rb") as f:
            return f.read()

    def remove(self, key):
        """Delete the file for ``key`` on a best-effort basis; never raises."""
        logger.info("Remove value of key {} start.".format(key))
        try:
            os.remove(self.__gen_file_path(key))
        except FileNotFoundError:
            # Nothing to delete: this is the expected best-effort case.
            pass
        except Exception:
            # Removal stays best-effort, but unexpected failures (permissions,
            # the path being a directory, ...) are logged instead of hidden.
            logger.warning(
                "Remove value of key {} failed.".format(key), exc_info=True)

    def rename(self, src, dst):
        """Rename the file for key ``src`` to the file for key ``dst``."""
        logger.info("rename {} to {}".format(src, dst))
        os.rename(self.__gen_file_path(src), self.__gen_file_path(dst))

    def exists(self, key) -> bool:
        """Return True if a path for ``key`` exists under the root directory."""
        # Resolve the key against the root directory like every other method
        # does; checking the bare key would test the process working
        # directory instead of this backend's storage location.
        return os.path.exists(self.__gen_file_path(key))

    def __gen_file_path(self, key):
        """Map ``key`` to its path inside the root directory."""
        return path.join(self.__dir, key)
|
||||
|
||||
|
||||
class AtomicFsContextBackend(LocalFileContextBackend):
    """File backend that writes each value to ``<key>_tmp`` and then renames it.

    When only the ``_tmp`` file exists for a key, reads fall back to it and the
    next write first renames it to the final key.
    """

    def __init__(self, conf):
        super().__init__(conf)
        # Suffix appended to a key to name its temporary file.
        self.__tmp_flag = "_tmp"

    def put(self, key, value):
        """Write ``value`` to the temporary key, then move it onto ``key``."""
        staging_key = key + self.__tmp_flag
        if super().exists(staging_key):
            if not super().exists(key):
                # Only the temporary file is present: promote it first.
                super().rename(staging_key, key)
        super().put(staging_key, value)
        super().remove(key)
        super().rename(staging_key, key)

    def get(self, key):
        """Read ``key``, or its temporary file when only that one exists."""
        staging_key = key + self.__tmp_flag
        source_key = key
        if super().exists(staging_key) and not super().exists(key):
            source_key = staging_key
        return super().get(source_key)

    def remove(self, key):
        """Remove the temporary file (when present) and then ``key`` itself."""
        staging_key = key + self.__tmp_flag
        if super().exists(staging_key):
            super().remove(staging_key)
        super().remove(key)
|
||||
|
||||
|
||||
class ContextBackendFactory:
    """Picks a context backend implementation from the worker configuration."""

    @staticmethod
    def get_context_backend(worker_config) -> ContextBackend:
        """Build the backend selected by the configured backend type.

        Returns an ``AtomicFsContextBackend`` for the local-file type, a
        ``MemoryContextBackend`` for the memory type, and None for any other
        configured type.
        """
        kind = ConfigHelper.get_cp_context_backend_type(worker_config)
        if kind == Config.CP_STATE_BACKEND_LOCAL_FILE:
            return AtomicFsContextBackend(worker_config)
        if kind == Config.CP_STATE_BACKEND_MEMORY:
            return MemoryContextBackend(worker_config)
        # No matching backend type.
        return None
|
||||
@@ -1,30 +0,0 @@
|
||||
class Barrier:
    """Barrier identified by a single ``id`` value."""

    def __init__(self, id):
        # The parameter name is part of the constructor's interface.
        self.id = id

    def __str__(self):
        """Render the barrier as ``Barrier [id:<id>]``."""
        return "Barrier [id:%s]" % self.id
|
||||
|
||||
|
||||
class OpCheckpointInfo:
    """Checkpoint record for one operator.

    Holds the operator's own checkpoint object, the input and output points,
    and the checkpoint id.
    """

    def __init__(self,
                 operator_point=None,
                 input_points=None,
                 output_points=None,
                 checkpoint_id=None):
        self.operator_point = operator_point
        # None means "no points given": each instance gets its own fresh dict.
        self.input_points = {} if input_points is None else input_points
        self.output_points = {} if output_points is None else output_points
        self.checkpoint_id = checkpoint_id
|
||||
@@ -5,9 +5,6 @@ import ray
|
||||
import ray.streaming.generated.remote_call_pb2 as remote_call_pb
|
||||
import ray.streaming.operator as operator
|
||||
import ray.streaming.partition as partition
|
||||
from ray._raylet import ActorID
|
||||
from ray.actor import ActorHandle
|
||||
from ray.streaming.config import Config
|
||||
from ray.streaming.generated.streaming_pb2 import Language
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -30,12 +27,10 @@ class NodeType(enum.Enum):
|
||||
|
||||
|
||||
class ExecutionEdge:
|
||||
def __init__(self, execution_edge_pb, language):
|
||||
self.source_execution_vertex_id = execution_edge_pb \
|
||||
.source_execution_vertex_id
|
||||
self.target_execution_vertex_id = execution_edge_pb \
|
||||
.target_execution_vertex_id
|
||||
partition_bytes = execution_edge_pb.partition
|
||||
def __init__(self, edge_pb, language):
|
||||
self.source_execution_vertex_id = edge_pb.source_execution_vertex_id
|
||||
self.target_execution_vertex_id = edge_pb.target_execution_vertex_id
|
||||
partition_bytes = edge_pb.partition
|
||||
# Sink node doesn't have partition function,
|
||||
# so we only deserialize partition_bytes when it's not None or empty
|
||||
if language == Language.PYTHON and partition_bytes:
|
||||
@@ -43,73 +38,50 @@ class ExecutionEdge:
|
||||
|
||||
|
||||
class ExecutionVertex:
|
||||
worker_actor: ActorHandle
|
||||
|
||||
def __init__(self, execution_vertex_pb):
|
||||
self.execution_vertex_id = execution_vertex_pb.execution_vertex_id
|
||||
self.execution_job_vertex_id = execution_vertex_pb \
|
||||
.execution_job_vertex_id
|
||||
self.execution_job_vertex_name = execution_vertex_pb \
|
||||
.execution_job_vertex_name
|
||||
self.execution_vertex_index = execution_vertex_pb\
|
||||
.execution_vertex_index
|
||||
self.parallelism = execution_vertex_pb.parallelism
|
||||
if execution_vertex_pb\
|
||||
.language == Language.PYTHON:
|
||||
# python operator descriptor
|
||||
operator_bytes = execution_vertex_pb.operator
|
||||
if execution_vertex_pb.chained:
|
||||
def __init__(self, vertex_pb):
|
||||
self.execution_vertex_id = vertex_pb.execution_vertex_id
|
||||
self.execution_job_vertex_Id = vertex_pb.execution_job_vertex_Id
|
||||
self.execution_job_vertex_name = vertex_pb.execution_job_vertex_name
|
||||
self.execution_vertex_index = vertex_pb.execution_vertex_index
|
||||
self.parallelism = vertex_pb.parallelism
|
||||
if vertex_pb.language == Language.PYTHON:
|
||||
operator_bytes = vertex_pb.operator # python operator descriptor
|
||||
if vertex_pb.chained:
|
||||
logger.info("Load chained operator")
|
||||
self.stream_operator = operator.load_chained_operator(
|
||||
operator_bytes)
|
||||
else:
|
||||
logger.info("Load operator")
|
||||
self.stream_operator = operator.load_operator(operator_bytes)
|
||||
self.worker_actor = None
|
||||
if execution_vertex_pb.worker_actor:
|
||||
self.worker_actor = ray.actor.ActorHandle. \
|
||||
_deserialization_helper(execution_vertex_pb.worker_actor)
|
||||
self.container_id = execution_vertex_pb.container_id
|
||||
self.build_time = execution_vertex_pb.build_time
|
||||
self.language = execution_vertex_pb.language
|
||||
self.config = execution_vertex_pb.config
|
||||
self.resource = execution_vertex_pb.resource
|
||||
|
||||
@property
|
||||
def execution_vertex_name(self):
|
||||
return "{}_{}_{}".format(self.execution_job_vertex_id,
|
||||
self.execution_job_vertex_name,
|
||||
self.execution_vertex_id)
|
||||
self.worker_actor = ray.actor.ActorHandle. \
|
||||
_deserialization_helper(vertex_pb.worker_actor)
|
||||
self.container_id = vertex_pb.container_id
|
||||
self.build_time = vertex_pb.build_time
|
||||
self.language = vertex_pb.language
|
||||
self.config = vertex_pb.config
|
||||
self.resource = vertex_pb.resource
|
||||
|
||||
|
||||
class ExecutionVertexContext:
|
||||
actor_id: ActorID
|
||||
execution_vertex: ExecutionVertex
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
execution_vertex_context_pb: remote_call_pb.ExecutionVertexContext
|
||||
):
|
||||
self.execution_vertex = ExecutionVertex(
|
||||
execution_vertex_context_pb.current_execution_vertex)
|
||||
self.job_name = self.execution_vertex.config[Config.STREAMING_JOB_NAME]
|
||||
self.exe_vertex_name = self.execution_vertex.execution_vertex_name
|
||||
self.actor_id = self.execution_vertex.worker_actor._ray_actor_id
|
||||
def __init__(self,
|
||||
vertex_context_pb: remote_call_pb.ExecutionVertexContext):
|
||||
self.execution_vertex = \
|
||||
ExecutionVertex(vertex_context_pb.current_execution_vertex)
|
||||
self.upstream_execution_vertices = [
|
||||
ExecutionVertex(vertex) for vertex in
|
||||
execution_vertex_context_pb.upstream_execution_vertices
|
||||
ExecutionVertex(vertex)
|
||||
for vertex in vertex_context_pb.upstream_execution_vertices
|
||||
]
|
||||
self.downstream_execution_vertices = [
|
||||
ExecutionVertex(vertex) for vertex in
|
||||
execution_vertex_context_pb.downstream_execution_vertices
|
||||
ExecutionVertex(vertex)
|
||||
for vertex in vertex_context_pb.downstream_execution_vertices
|
||||
]
|
||||
self.input_execution_edges = [
|
||||
ExecutionEdge(edge, self.execution_vertex.language)
|
||||
for edge in execution_vertex_context_pb.input_execution_edges
|
||||
for edge in vertex_context_pb.input_execution_edges
|
||||
]
|
||||
self.output_execution_edges = [
|
||||
ExecutionEdge(edge, self.execution_vertex.language)
|
||||
for edge in execution_vertex_context_pb.output_execution_edges
|
||||
for edge in vertex_context_pb.output_execution_edges
|
||||
]
|
||||
|
||||
def get_parallelism(self):
|
||||
@@ -140,16 +112,16 @@ class ExecutionVertexContext:
|
||||
def get_task_id(self):
|
||||
return self.execution_vertex.execution_vertex_id
|
||||
|
||||
def get_source_actor_by_execution_vertex_id(self, execution_vertex_id):
    """Return the worker actor of the upstream vertex with the given id.

    Raises:
        Exception: if no upstream execution vertex has that id.
    """
    for execution_vertex in self.upstream_execution_vertices:
        if execution_vertex.execution_vertex_id == execution_vertex_id:
            return execution_vertex.worker_actor
    # str.format only fills "{}" placeholders; the former "%s" was left
    # verbatim in the message and the id was never shown.
    raise Exception(
        "Vertex {} does not exist!".format(execution_vertex_id))
|
||||
def get_source_actor_by_vertex_id(self, execution_vertex_id):
    """Return the worker actor of the upstream vertex with the given id.

    Raises:
        Exception: if no upstream execution vertex has that id.
    """
    for vertex in self.upstream_execution_vertices:
        if vertex.execution_vertex_id == execution_vertex_id:
            return vertex.worker_actor
    # str.format only fills "{}" placeholders; the former "%s" was left
    # verbatim in the message and the id was never shown.
    raise Exception("ExecutionVertex {} does not exist!"
                    .format(execution_vertex_id))
|
||||
|
||||
def get_target_actor_by_execution_vertex_id(self, execution_vertex_id):
    """Return the worker actor of the downstream vertex with the given id.

    Raises:
        Exception: if no downstream execution vertex has that id.
    """
    for execution_vertex in self.downstream_execution_vertices:
        if execution_vertex.execution_vertex_id == execution_vertex_id:
            return execution_vertex.worker_actor
    # str.format only fills "{}" placeholders; the former "%s" was left
    # verbatim in the message and the id was never shown.
    raise Exception(
        "Vertex {} does not exist!".format(execution_vertex_id))
|
||||
def get_target_actor_by_vertex_id(self, execution_vertex_id):
    """Return the worker actor of the downstream vertex with the given id.

    Raises:
        Exception: if no downstream execution vertex has that id.
    """
    for vertex in self.downstream_execution_vertices:
        if vertex.execution_vertex_id == execution_vertex_id:
            return vertex.worker_actor
    # str.format only fills "{}" placeholders; the former "%s" was left
    # verbatim in the message and the id was never shown.
    raise Exception("ExecutionVertex {} does not exist!"
                    .format(execution_vertex_id))
|
||||
|
||||
@@ -23,14 +23,6 @@ class Processor(ABC):
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_checkpoint(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_checkpoint(self, checkpoint_obj):
|
||||
pass
|
||||
|
||||
|
||||
class StreamingProcessor(Processor, ABC):
|
||||
"""StreamingProcessor is a process unit for a operator."""
|
||||
@@ -48,13 +40,7 @@ class StreamingProcessor(Processor, ABC):
|
||||
logger.info("Opened Processor {}".format(self))
|
||||
|
||||
def close(self):
|
||||
self.operator.close()
|
||||
|
||||
def save_checkpoint(self):
|
||||
self.operator.save_checkpoint()
|
||||
|
||||
def load_checkpoint(self, checkpoint_obj):
|
||||
self.operator.load_checkpoint(checkpoint_obj)
|
||||
pass
|
||||
|
||||
|
||||
class SourceProcessor(StreamingProcessor):
|
||||
@@ -66,8 +52,8 @@ class SourceProcessor(StreamingProcessor):
|
||||
def process(self, record):
|
||||
raise Exception("SourceProcessor should not process record")
|
||||
|
||||
def fetch(self):
|
||||
self.operator.fetch()
|
||||
def run(self):
|
||||
self.operator.run()
|
||||
|
||||
|
||||
class OneInputProcessor(StreamingProcessor):
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import ray
|
||||
import time
|
||||
from enum import Enum
|
||||
|
||||
from ray.actor import ActorHandle
|
||||
from ray.streaming.generated import remote_call_pb2
|
||||
from ray.streaming.runtime.command\
|
||||
import WorkerCommitReport, WorkerRollbackRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CallResult:
    """Outcome of a call: success flag, result code, message and payload.

    Note that the ``success`` instance attribute set in ``__init__`` hides the
    ``success`` factory on instances; call the factory on the class.
    """

    def __init__(self, success, result_code, result_msg, result_obj):
        self.success = success
        self.result_code = result_code
        self.result_msg = result_msg
        self.result_obj = result_obj

    @staticmethod
    def success(payload=None):
        """Build a successful result that carries ``payload``."""
        return CallResult(
            success=True,
            result_code=CallResultEnum.SUCCESS,
            result_msg=None,
            result_obj=payload)

    @staticmethod
    def fail(payload=None):
        """Build a failed result that carries ``payload``."""
        return CallResult(
            success=False,
            result_code=CallResultEnum.FAILED,
            result_msg=None,
            result_obj=payload)

    @staticmethod
    def skipped(msg=None):
        """Build a skipped result (flagged successful) with message ``msg``."""
        return CallResult(
            success=True,
            result_code=CallResultEnum.SKIPPED,
            result_msg=msg,
            result_obj=None)

    def is_success(self):
        """Return True only when the result code is ``CallResultEnum.SUCCESS``."""
        return self.result_code is CallResultEnum.SUCCESS
|
||||
|
||||
|
||||
class CallResultEnum(Enum):
    """Result codes carried by ``CallResult``."""

    # Code used by CallResult.success().
    SUCCESS = 0
    # Code used by CallResult.fail().
    FAILED = 1
    # Code used by CallResult.skipped().
    SKIPPED = 2
|
||||
|
||||
|
||||
class RemoteCallMst:
    """Helpers for remote calls from a worker to the job master actor.

    Each call wraps a detail message in a ``BaseWorkerCmd`` protobuf, sends the
    serialized bytes to a method on the master actor handle, blocks on the
    reply with ``ray.get`` and parses it as a ``BoolResult``.
    """

    @staticmethod
    def _build_worker_cmd(actor_id, detail_pb):
        """Wrap ``detail_pb`` in a ``BaseWorkerCmd`` stamped with the sender id
        and the current time in milliseconds."""
        cmd_pb = remote_call_pb2.BaseWorkerCmd()
        cmd_pb.actor_id = actor_id
        cmd_pb.timestamp = int(time.time() * 1000.0)
        cmd_pb.detail.Pack(detail_pb)
        return cmd_pb

    @staticmethod
    def _read_bool_result(return_id):
        """Wait for the remote reply and return its ``boolRes`` field."""
        result = remote_call_pb2.BoolResult()
        result.ParseFromString(ray.get(return_id))
        return result.boolRes

    @staticmethod
    def request_job_worker_rollback(master: ActorHandle,
                                    request: WorkerRollbackRequest):
        """Ask the master to roll back this worker.

        The request carries the exception message plus this process's host
        name and pid. Returns the ``boolRes`` value of the master's reply.
        """
        # Imported locally so the block does not depend on a new file-level
        # import. os.uname() is unavailable on Windows; platform.node()
        # reports the node name on every platform.
        import platform

        logger.info("Remote call mst: request job worker rollback start.")
        rollback_request_pb = remote_call_pb2.WorkerRollbackRequest()
        rollback_request_pb.exception_msg = request.exception_msg()
        rollback_request_pb.worker_hostname = platform.node()
        rollback_request_pb.worker_pid = str(os.getpid())
        request_pb = RemoteCallMst._build_worker_cmd(request.from_actor_id,
                                                     rollback_request_pb)
        return_ids = master.requestJobWorkerRollback\
            .remote(request_pb.SerializeToString())
        succeeded = RemoteCallMst._read_bool_result(return_ids)
        logger.info("Remote call mst: request job worker rollback finish.")
        return succeeded

    @staticmethod
    def report_job_worker_commit(master: ActorHandle,
                                 report: WorkerCommitReport):
        """Report a committed checkpoint id to the master.

        Returns the ``boolRes`` value of the master's reply.
        """
        logger.info("Remote call mst: report job worker commit start.")
        wk_commit = remote_call_pb2.WorkerCommitReport()
        wk_commit.commit_checkpoint_id = report.commit_checkpoint_id
        report_pb = RemoteCallMst._build_worker_cmd(report.from_actor_id,
                                                    wk_commit)
        return_id = master.reportJobWorkerCommit\
            .remote(report_pb.SerializeToString())
        succeeded = RemoteCallMst._read_bool_result(return_id)
        logger.info("Remote call mst: report job worker commit finish.")
        return succeeded
|
||||
@@ -3,11 +3,11 @@ import pickle
|
||||
import msgpack
|
||||
from ray.streaming import message
|
||||
|
||||
RECORD_TYPE_ID = 0
|
||||
KEY_RECORD_TYPE_ID = 1
|
||||
CROSS_LANG_TYPE_ID = 0
|
||||
JAVA_TYPE_ID = 1
|
||||
PYTHON_TYPE_ID = 2
|
||||
_RECORD_TYPE_ID = 0
|
||||
_KEY_RECORD_TYPE_ID = 1
|
||||
_CROSS_LANG_TYPE_ID = b"0"
|
||||
_JAVA_TYPE_ID = b"1"
|
||||
_PYTHON_TYPE_ID = b"2"
|
||||
|
||||
|
||||
class Serializer(ABC):
|
||||
@@ -33,21 +33,21 @@ class CrossLangSerializer(Serializer):
|
||||
|
||||
def serialize(self, obj):
|
||||
if type(obj) is message.Record:
|
||||
fields = [RECORD_TYPE_ID, obj.stream, obj.value]
|
||||
fields = [_RECORD_TYPE_ID, obj.stream, obj.value]
|
||||
elif type(obj) is message.KeyRecord:
|
||||
fields = [KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value]
|
||||
fields = [_KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value]
|
||||
else:
|
||||
raise Exception("Unsupported value {}".format(obj))
|
||||
return msgpack.packb(fields, use_bin_type=True)
|
||||
|
||||
def deserialize(self, data):
|
||||
fields = msgpack.unpackb(data, raw=False)
|
||||
if fields[0] == RECORD_TYPE_ID:
|
||||
fields = msgpack.unpackb(data, raw=False, strict_map_key=False)
|
||||
if fields[0] == _RECORD_TYPE_ID:
|
||||
stream, value = fields[1:]
|
||||
record = message.Record(value)
|
||||
record.stream = stream
|
||||
return record
|
||||
elif fields[0] == KEY_RECORD_TYPE_ID:
|
||||
elif fields[0] == _KEY_RECORD_TYPE_ID:
|
||||
stream, key, value = fields[1:]
|
||||
key_record = message.KeyRecord(key, value)
|
||||
key_record.stream = stream
|
||||
|
||||
@@ -1,30 +1,14 @@
|
||||
import logging
|
||||
import pickle
|
||||
import threading
|
||||
import time
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from ray.streaming.collector import OutputCollector
|
||||
from ray.streaming.config import Config
|
||||
from ray.streaming.context import RuntimeContextImpl
|
||||
from ray.streaming.generated import remote_call_pb2
|
||||
from ray.streaming.runtime import serialization
|
||||
from ray.streaming.runtime.command import WorkerCommitReport
|
||||
from ray.streaming.runtime.failover import Barrier, OpCheckpointInfo
|
||||
from ray.streaming.runtime.remote_call import RemoteCallMst
|
||||
from ray.streaming.runtime.serialization import \
|
||||
PythonSerializer, CrossLangSerializer
|
||||
from ray.streaming.runtime.transfer import CheckpointBarrier
|
||||
from ray.streaming.runtime.transfer import DataMessage
|
||||
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader
|
||||
from ray.streaming.runtime.transfer import ChannelRecoverInfo
|
||||
from ray.streaming.runtime.transfer import ChannelInterruptException
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ray.streaming.runtime.worker import JobWorker
|
||||
from ray.streaming.runtime.processor import Processor, SourceProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -32,85 +16,18 @@ logger = logging.getLogger(__name__)
|
||||
class StreamTask(ABC):
|
||||
"""Base class for all streaming tasks. Each task runs a processor."""
|
||||
|
||||
def __init__(self, task_id: int, processor: "Processor",
|
||||
worker: "JobWorker", last_checkpoint_id: int):
|
||||
self.worker_context = worker.worker_context
|
||||
self.vertex_context = worker.execution_vertex_context
|
||||
def __init__(self, task_id, processor, worker):
|
||||
self.task_id = task_id
|
||||
self.processor = processor
|
||||
self.worker = worker
|
||||
self.config: dict = worker.config
|
||||
self.reader: Optional[DataReader] = None
|
||||
self.writer: Optional[DataWriter] = None
|
||||
self.is_initial_state = True
|
||||
self.last_checkpoint_id: int = last_checkpoint_id
|
||||
self.config = worker.config
|
||||
self.reader = None # DataReader
|
||||
self.writers = {} # ExecutionEdge -> DataWriter
|
||||
self.thread = None
|
||||
self.prepare_task()
|
||||
self.thread = threading.Thread(target=self.run, daemon=True)
|
||||
|
||||
def do_checkpoint(self, checkpoint_id: int, input_points):
|
||||
logger.info("Start do checkpoint, cp id {}, inputPoints {}.".format(
|
||||
checkpoint_id, input_points))
|
||||
|
||||
output_points = None
|
||||
if self.writer is not None:
|
||||
output_points = self.writer.get_output_checkpoints()
|
||||
|
||||
operator_checkpoint = self.processor.save_checkpoint()
|
||||
op_checkpoint_info = OpCheckpointInfo(
|
||||
operator_checkpoint, input_points, output_points, checkpoint_id)
|
||||
self.__save_cp_state_and_report(op_checkpoint_info, checkpoint_id)
|
||||
|
||||
barrier_pb = remote_call_pb2.Barrier()
|
||||
barrier_pb.id = checkpoint_id
|
||||
byte_buffer = barrier_pb.SerializeToString()
|
||||
if self.writer is not None:
|
||||
self.writer.broadcast_barrier(checkpoint_id, byte_buffer)
|
||||
logger.info("Operator checkpoint {} finish.".format(checkpoint_id))
|
||||
|
||||
def __save_cp_state_and_report(self, op_checkpoint_info, checkpoint_id):
|
||||
logger.info(
|
||||
"Start to save cp state and report, checkpoint id is {}.".format(
|
||||
checkpoint_id))
|
||||
self.__save_cp(op_checkpoint_info, checkpoint_id)
|
||||
self.__report_commit(checkpoint_id)
|
||||
self.last_checkpoint_id = checkpoint_id
|
||||
|
||||
def __save_cp(self, op_checkpoint_info, checkpoint_id):
|
||||
logger.info("save operator cp, op_checkpoint_info={}".format(
|
||||
op_checkpoint_info))
|
||||
cp_bytes = pickle.dumps(op_checkpoint_info)
|
||||
self.worker.context_backend.put(
|
||||
self.__gen_op_checkpoint_key(checkpoint_id), cp_bytes)
|
||||
|
||||
def __report_commit(self, checkpoint_id: int):
|
||||
logger.info("Report commit, checkpoint id {}.".format(checkpoint_id))
|
||||
report = WorkerCommitReport(self.vertex_context.actor_id.binary(),
|
||||
checkpoint_id)
|
||||
RemoteCallMst.report_job_worker_commit(self.worker.master_actor,
|
||||
report)
|
||||
|
||||
def clear_expired_cp_state(self, checkpoint_id):
|
||||
cp_key = self.__gen_op_checkpoint_key(checkpoint_id)
|
||||
self.worker.context_backend.remove(cp_key)
|
||||
|
||||
def clear_expired_queue_msg(self, checkpoint_id):
|
||||
# clear operator checkpoint
|
||||
if self.writer is not None:
|
||||
self.writer.clear_checkpoint(checkpoint_id)
|
||||
|
||||
def request_rollback(self, exception_msg: str):
|
||||
self.worker.request_rollback(exception_msg)
|
||||
|
||||
def __gen_op_checkpoint_key(self, checkpoint_id):
|
||||
op_checkpoint_key = Config.JOB_WORKER_OP_CHECKPOINT_PREFIX_KEY + str(
|
||||
self.vertex_context.job_name) + "_" + str(
|
||||
self.vertex_context.exe_vertex_name) + "_" + str(checkpoint_id)
|
||||
logger.info(
|
||||
"Generate op checkpoint key {}. ".format(op_checkpoint_key))
|
||||
return op_checkpoint_key
|
||||
|
||||
def prepare_task(self, is_recreate: bool):
|
||||
logger.info(
|
||||
"Preparing stream task, is_recreate={}.".format(is_recreate))
|
||||
def prepare_task(self):
|
||||
channel_conf = dict(self.worker.config)
|
||||
channel_size = int(
|
||||
self.worker.config.get(Config.CHANNEL_SIZE,
|
||||
@@ -122,76 +39,45 @@ class StreamTask(ABC):
|
||||
execution_vertex_context = self.worker.execution_vertex_context
|
||||
build_time = execution_vertex_context.build_time
|
||||
|
||||
# when use memory state, if actor throw exception, will miss state
|
||||
op_checkpoint_info = OpCheckpointInfo()
|
||||
|
||||
cp_bytes = None
|
||||
# get operator checkpoint
|
||||
if is_recreate:
|
||||
cp_key = self.__gen_op_checkpoint_key(self.last_checkpoint_id)
|
||||
logger.info("Getting task checkpoints from state, "
|
||||
"cpKey={}, checkpointId={}.".format(
|
||||
cp_key, self.last_checkpoint_id))
|
||||
cp_bytes = self.worker.context_backend.get(cp_key)
|
||||
if cp_bytes is None:
|
||||
msg = "Task recover failed, checkpoint is null!"\
|
||||
"cpKey={}".format(cp_key)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
if cp_bytes is not None:
|
||||
op_checkpoint_info = pickle.loads(cp_bytes)
|
||||
self.processor.load_checkpoint(op_checkpoint_info.operator_point)
|
||||
logger.info("Stream task recover from checkpoint state,"
|
||||
"checkpoint bytes len={}, checkpointInfo={}.".format(
|
||||
cp_bytes.__len__(), op_checkpoint_info))
|
||||
|
||||
# writers
|
||||
collectors = []
|
||||
output_actors_map = {}
|
||||
for edge in execution_vertex_context.output_execution_edges:
|
||||
target_task_id = edge.target_execution_vertex_id
|
||||
target_actor = execution_vertex_context \
|
||||
.get_target_actor_by_execution_vertex_id(target_task_id)
|
||||
target_actor = execution_vertex_context\
|
||||
.get_target_actor_by_vertex_id(target_task_id)
|
||||
channel_name = ChannelID.gen_id(self.task_id, target_task_id,
|
||||
build_time)
|
||||
output_actors_map[channel_name] = target_actor
|
||||
|
||||
if len(output_actors_map) > 0:
|
||||
channel_str_ids = list(output_actors_map.keys())
|
||||
target_actors = list(output_actors_map.values())
|
||||
logger.info("Create DataWriter channel_ids {},"
|
||||
"target_actors {}, output_points={}.".format(
|
||||
channel_str_ids, target_actors,
|
||||
op_checkpoint_info.output_points))
|
||||
self.writer = DataWriter(channel_str_ids, target_actors,
|
||||
channel_conf)
|
||||
logger.info("Create DataWriter succeed channel_ids {}, "
|
||||
"target_actors {}.".format(channel_str_ids,
|
||||
target_actors))
|
||||
for edge in execution_vertex_context.output_execution_edges:
|
||||
if len(output_actors_map) > 0:
|
||||
channel_ids = list(output_actors_map.keys())
|
||||
target_actors = list(output_actors_map.values())
|
||||
logger.info(
|
||||
"Create DataWriter channel_ids {}, target_actors {}."
|
||||
.format(channel_ids, target_actors))
|
||||
writer = DataWriter(channel_ids, target_actors, channel_conf)
|
||||
self.writers[edge] = writer
|
||||
collectors.append(
|
||||
OutputCollector(self.writer, channel_str_ids,
|
||||
target_actors, edge.partition))
|
||||
OutputCollector(writer, channel_ids, target_actors,
|
||||
edge.partition))
|
||||
|
||||
# readers
|
||||
input_actor_map = {}
|
||||
for edge in execution_vertex_context.input_execution_edges:
|
||||
source_task_id = edge.source_execution_vertex_id
|
||||
source_actor = execution_vertex_context \
|
||||
.get_source_actor_by_execution_vertex_id(source_task_id)
|
||||
source_actor = execution_vertex_context\
|
||||
.get_source_actor_by_vertex_id(source_task_id)
|
||||
channel_name = ChannelID.gen_id(source_task_id, self.task_id,
|
||||
build_time)
|
||||
input_actor_map[channel_name] = source_actor
|
||||
|
||||
if len(input_actor_map) > 0:
|
||||
channel_str_ids = list(input_actor_map.keys())
|
||||
channel_ids = list(input_actor_map.keys())
|
||||
from_actors = list(input_actor_map.values())
|
||||
logger.info("Create DataReader, channels {},"
|
||||
"input_actors {}, input_points={}.".format(
|
||||
channel_str_ids, from_actors,
|
||||
op_checkpoint_info.input_points))
|
||||
self.reader = DataReader(channel_str_ids, from_actors,
|
||||
channel_conf)
|
||||
logger.info("Create DataReader, channels {}, input_actors {}."
|
||||
.format(channel_ids, from_actors))
|
||||
self.reader = DataReader(channel_ids, from_actors, channel_conf)
|
||||
|
||||
def exit_handler():
|
||||
# Make DataReader stop read data when MockQueue destructor
|
||||
@@ -201,31 +87,21 @@ class StreamTask(ABC):
|
||||
import atexit
|
||||
atexit.register(exit_handler)
|
||||
|
||||
# TODO(chaokunyang) add task/job config
|
||||
runtime_context = RuntimeContextImpl(
|
||||
self.worker.task_id,
|
||||
execution_vertex_context.execution_vertex.execution_vertex_index,
|
||||
execution_vertex_context.get_parallelism(),
|
||||
config=channel_conf,
|
||||
job_config=channel_conf)
|
||||
execution_vertex_context.get_parallelism())
|
||||
logger.info("open Processor {}".format(self.processor))
|
||||
self.processor.open(collectors, runtime_context)
|
||||
|
||||
# immediately save cp. In case of FO in cp 0
|
||||
# or use old cp in multi node FO.
|
||||
self.__save_cp(op_checkpoint_info, self.last_checkpoint_id)
|
||||
|
||||
def recover(self, is_recreate: bool):
|
||||
self.prepare_task(is_recreate)
|
||||
|
||||
recover_info = ChannelRecoverInfo()
|
||||
if self.reader is not None:
|
||||
recover_info = self.reader.get_channel_recover_info()
|
||||
@abstractmethod
|
||||
def init(self):
|
||||
pass
|
||||
|
||||
def start(self):
|
||||
self.thread.start()
|
||||
|
||||
logger.info("Start operator success.")
|
||||
return recover_info
|
||||
|
||||
@abstractmethod
|
||||
def run(self):
|
||||
pass
|
||||
@@ -234,24 +110,14 @@ class StreamTask(ABC):
|
||||
def cancel_task(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit_trigger(self, barrier: Barrier) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class InputStreamTask(StreamTask):
|
||||
"""Base class for stream tasks that execute a
|
||||
:class:`runtime.processor.OneInputProcessor` or
|
||||
:class:`runtime.processor.TwoInputProcessor` """
|
||||
|
||||
def commit_trigger(self, barrier):
|
||||
raise RuntimeError(
|
||||
"commit_trigger is only supported in SourceStreamTask.")
|
||||
|
||||
def __init__(self, task_id, processor_instance, worker,
|
||||
last_checkpoint_id):
|
||||
super().__init__(task_id, processor_instance, worker,
|
||||
last_checkpoint_id)
|
||||
def __init__(self, task_id, processor_instance, worker):
|
||||
super().__init__(task_id, processor_instance, worker)
|
||||
self.running = True
|
||||
self.stopped = False
|
||||
self.read_timeout_millis = \
|
||||
@@ -260,58 +126,25 @@ class InputStreamTask(StreamTask):
|
||||
self.python_serializer = PythonSerializer()
|
||||
self.cross_lang_serializer = CrossLangSerializer()
|
||||
|
||||
def init(self):
|
||||
pass
|
||||
|
||||
def run(self):
|
||||
logger.info("Input task thread start.")
|
||||
try:
|
||||
while self.running:
|
||||
self.worker.initial_state_lock.acquire()
|
||||
try:
|
||||
item = self.reader.read(self.read_timeout_millis)
|
||||
self.is_initial_state = False
|
||||
finally:
|
||||
self.worker.initial_state_lock.release()
|
||||
|
||||
if item is None:
|
||||
continue
|
||||
|
||||
if isinstance(item, DataMessage):
|
||||
msg_data = item.body
|
||||
type_id = msg_data[0]
|
||||
if type_id == serialization.PYTHON_TYPE_ID:
|
||||
msg = self.python_serializer.deserialize(msg_data[1:])
|
||||
else:
|
||||
msg = self.cross_lang_serializer.deserialize(
|
||||
msg_data[1:])
|
||||
self.processor.process(msg)
|
||||
elif isinstance(item, CheckpointBarrier):
|
||||
logger.info("Got barrier:{}".format(item))
|
||||
logger.info("Start to do checkpoint {}.".format(
|
||||
item.checkpoint_id))
|
||||
|
||||
input_points = item.get_input_checkpoints()
|
||||
|
||||
self.do_checkpoint(item.checkpoint_id, input_points)
|
||||
logger.info("Do checkpoint {} success.".format(
|
||||
item.checkpoint_id))
|
||||
while self.running:
|
||||
item = self.reader.read(self.read_timeout_millis)
|
||||
if item is not None:
|
||||
msg_data = item.body()
|
||||
type_id = msg_data[:1]
|
||||
if (type_id == serialization._PYTHON_TYPE_ID):
|
||||
msg = self.python_serializer.deserialize(msg_data[1:])
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Unknown item type! item={}".format(item))
|
||||
|
||||
except ChannelInterruptException:
|
||||
logger.info("queue has stopped.")
|
||||
except BaseException as e:
|
||||
logger.exception(
|
||||
"Last success checkpointId={}, now occur error.".format(
|
||||
self.last_checkpoint_id))
|
||||
self.request_rollback(str(e))
|
||||
|
||||
logger.info("Source fetcher thread exit.")
|
||||
msg = self.cross_lang_serializer.deserialize(msg_data[1:])
|
||||
self.processor.process(msg)
|
||||
self.stopped = True
|
||||
|
||||
def cancel_task(self):
    """Request the task to stop and block until it reports being stopped.

    Clears ``self.running`` and then polls ``self.stopped`` every 0.5
    seconds until it becomes truthy.
    """
    self.running = False
    # Nothing in this method sets `stopped`; it is expected to be flipped
    # by whatever loop observes `running` going False.
    while not self.stopped:
        time.sleep(0.5)
|
||||
|
||||
|
||||
@@ -319,64 +152,22 @@ class OneInputStreamTask(InputStreamTask):
|
||||
"""A stream task for executing :class:`runtime.processor.OneInputProcessor`
|
||||
"""
|
||||
|
||||
def __init__(self, task_id, processor_instance, worker,
|
||||
last_checkpoint_id):
|
||||
super().__init__(task_id, processor_instance, worker,
|
||||
last_checkpoint_id)
|
||||
def __init__(self, task_id, processor_instance, worker):
|
||||
super().__init__(task_id, processor_instance, worker)
|
||||
|
||||
|
||||
class SourceStreamTask(StreamTask):
|
||||
"""A stream task for executing :class:`runtime.processor.SourceProcessor`
|
||||
"""
|
||||
processor: "SourceProcessor"
|
||||
|
||||
def __init__(self, task_id: int, processor_instance: "SourceProcessor",
|
||||
worker: "JobWorker", last_checkpoint_id):
|
||||
super().__init__(task_id, processor_instance, worker,
|
||||
last_checkpoint_id)
|
||||
self.running = True
|
||||
self.stopped = False
|
||||
self.__pending_barrier: Optional[Barrier] = None
|
||||
def __init__(self, task_id, processor_instance, worker):
|
||||
super().__init__(task_id, processor_instance, worker)
|
||||
|
||||
def init(self):
|
||||
pass
|
||||
|
||||
def run(self):
|
||||
logger.info("Source task thread start.")
|
||||
try:
|
||||
while self.running:
|
||||
self.processor.fetch()
|
||||
# check checkpoint
|
||||
if self.__pending_barrier is not None:
|
||||
# source fetcher only have outputPoints
|
||||
barrier = self.__pending_barrier
|
||||
logger.info("Start to do checkpoint {}.".format(
|
||||
barrier.id))
|
||||
self.do_checkpoint(barrier.id, barrier)
|
||||
logger.info("Finish to do checkpoint {}.".format(
|
||||
barrier.id))
|
||||
self.__pending_barrier = None
|
||||
|
||||
except ChannelInterruptException:
|
||||
logger.info("queue has stopped.")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Last success checkpointId={}, now occur error.".format(
|
||||
self.last_checkpoint_id))
|
||||
self.request_rollback(str(e))
|
||||
|
||||
logger.info("Source fetcher thread exit.")
|
||||
self.stopped = True
|
||||
|
||||
def commit_trigger(self, barrier):
|
||||
if self.__pending_barrier is not None:
|
||||
logger.warning(
|
||||
"Last barrier is not broadcast now, skip this barrier trigger."
|
||||
)
|
||||
return False
|
||||
|
||||
self.__pending_barrier = barrier
|
||||
return True
|
||||
self.processor.run()
|
||||
|
||||
def cancel_task(self):
|
||||
self.running = False
|
||||
while not self.stopped:
|
||||
time.sleep(0.5)
|
||||
pass
|
||||
pass
|
||||
|
||||
@@ -2,8 +2,6 @@ import logging
|
||||
import random
|
||||
from queue import Queue
|
||||
from typing import List
|
||||
from enum import Enum
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import ray
|
||||
import ray.streaming._streaming as _streaming
|
||||
@@ -15,7 +13,6 @@ from ray._raylet import PythonFunctionDescriptor
|
||||
from ray._raylet import Language
|
||||
|
||||
CHANNEL_ID_LEN = 20
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChannelID:
|
||||
@@ -100,70 +97,40 @@ def channel_bytes_to_str(id_bytes):
|
||||
return bytes.hex(id_bytes)
|
||||
|
||||
|
||||
class Message(ABC):
    """Abstract interface shared by the items a channel can deliver.

    Concrete subclasses must provide the four read-only properties below.
    """

    @property
    @abstractmethod
    def body(self):
        """Message data"""

    @property
    @abstractmethod
    def timestamp(self):
        """Get timestamp when item is written by upstream DataWriter"""

    @property
    @abstractmethod
    def channel_id(self):
        """Get string id of channel where data is coming from"""

    @property
    @abstractmethod
    def message_id(self):
        """Get message id of the message"""
|
||||
|
||||
|
||||
class DataMessage(Message):
|
||||
class DataMessage:
|
||||
"""
|
||||
DataMessage represents data between upstream and downstream operator.
|
||||
DataMessage represents data between upstream and downstream operator
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
body,
|
||||
timestamp,
|
||||
message_id,
|
||||
channel_id,
|
||||
message_id_,
|
||||
is_empty_message=False):
|
||||
self.__body = body
|
||||
self.__timestamp = timestamp
|
||||
self.__channel_id = channel_id
|
||||
self.__message_id = message_id
|
||||
self.__message_id = message_id_
|
||||
self.__is_empty_message = is_empty_message
|
||||
|
||||
def __len__(self):
|
||||
return len(self.__body)
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
"""Message data"""
|
||||
return self.__body
|
||||
|
||||
@property
|
||||
def timestamp(self):
|
||||
"""Get timestamp when item is written by upstream DataWriter
|
||||
"""
|
||||
return self.__timestamp
|
||||
|
||||
@property
|
||||
def channel_id(self):
|
||||
"""Get string id of channel where data is coming from
|
||||
"""
|
||||
return self.__channel_id
|
||||
|
||||
@property
|
||||
def message_id(self):
|
||||
return self.__message_id
|
||||
|
||||
@property
|
||||
def is_empty_message(self):
|
||||
"""Whether this message is an empty message.
|
||||
Upstream DataWriter will send an empty message when this is no data
|
||||
@@ -171,47 +138,10 @@ class DataMessage(Message):
|
||||
"""
|
||||
return self.__is_empty_message
|
||||
|
||||
|
||||
class CheckpointBarrier(Message):
    """
    A barrier splits the data stream into the records that belong to the
    current snapshot and the records that belong to the next one. Each
    barrier carries the ID of the snapshot whose records it pushed in
    front of it.
    """

    def __init__(self, barrier_data, timestamp, message_id, channel_id,
                 offsets, barrier_id, barrier_type):
        # Public fields, read directly by callers.
        self.checkpoint_id = barrier_id
        self.offsets = offsets
        self.barrier_type = barrier_type
        # Private fields, exposed through the Message properties below.
        self.__barrier_data = barrier_data
        self.__timestamp = timestamp
        self.__message_id = message_id
        self.__channel_id = channel_id

    @property
    def body(self):
        """Barrier payload passed in as ``barrier_data``."""
        return self.__barrier_data

    @property
    def timestamp(self):
        """Timestamp passed to the constructor."""
        return self.__timestamp

    @property
    def channel_id(self):
        """Channel id passed to the constructor."""
        return self.__channel_id

    @property
    def message_id(self):
        """Message id passed to the constructor."""
        return self.__message_id

    def get_input_checkpoints(self):
        """Return the ``offsets`` this barrier was created with."""
        return self.offsets

    def __str__(self):
        text = "Barrier(Checkpoint id : {})"
        return text.format(self.checkpoint_id)
|
||||
|
||||
|
||||
class ChannelCreationParametersBuilder:
|
||||
"""
|
||||
@@ -288,6 +218,9 @@ class ChannelCreationParametersBuilder:
|
||||
_python_reader_sync_function_descriptor = sync_function
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataWriter:
|
||||
"""Data Writer is a wrapper of streaming c++ DataWriter, which sends data
|
||||
to downstream workers
|
||||
@@ -331,26 +264,6 @@ class DataWriter:
|
||||
msg_id = self.writer.write(channel_id.object_qid, item)
|
||||
return msg_id
|
||||
|
||||
def broadcast_barrier(self, checkpoint_id: int, body: bytes):
    """Broadcast barriers to all downstream channels.

    Delegates to the wrapped writer's ``broadcast_barrier``.

    Args:
        checkpoint_id: the checkpoint_id
        body: barrier payload
    """
    native_writer = self.writer
    native_writer.broadcast_barrier(checkpoint_id, body)
|
||||
|
||||
def get_output_checkpoints(self) -> List[int]:
    """Get output offsets of all downstream channels.

    Delegates to the wrapped writer's ``get_output_checkpoints``.

    Returns:
        a list contains current msg_id of each downstream channel
    """
    offsets = self.writer.get_output_checkpoints()
    return offsets
|
||||
|
||||
def clear_checkpoint(self, checkpoint_id):
    """Log the request, then ask the wrapped writer to clear a checkpoint.

    Args:
        checkpoint_id: id forwarded unchanged to ``self.writer``
    """
    message = ("producer start to clear checkpoint, checkpoint_id={}"
               .format(checkpoint_id))
    logger.info(message)
    self.writer.clear_checkpoint(checkpoint_id)
|
||||
|
||||
def stop(self):
|
||||
logger.info("stopping channel writer.")
|
||||
self.writer.stop()
|
||||
@@ -381,20 +294,18 @@ class DataReader:
|
||||
]
|
||||
creation_parameters = ChannelCreationParametersBuilder()
|
||||
creation_parameters.build_input_queue_parameters(from_actors)
|
||||
py_seq_ids = [0 for _ in range(len(input_channels))]
|
||||
py_msg_ids = [0 for _ in range(len(input_channels))]
|
||||
timer_interval = int(conf.get(Config.TIMER_INTERVAL_MS, -1))
|
||||
is_recreate = bool(conf.get(Config.IS_RECREATE, False))
|
||||
config_bytes = _to_native_conf(conf)
|
||||
self.__queue = Queue(10000)
|
||||
is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL
|
||||
self.reader, queues_creation_status = _streaming.DataReader.create(
|
||||
self.reader = _streaming.DataReader.create(
|
||||
py_input_channels, creation_parameters.get_parameters(),
|
||||
py_msg_ids, timer_interval, config_bytes, is_mock)
|
||||
|
||||
self.__creation_status = {}
|
||||
for q, status in queues_creation_status.items():
|
||||
self.__creation_status[q] = ChannelCreationStatus(status)
|
||||
logger.info("create DataReader succeed, creation_status={}".format(
|
||||
self.__creation_status))
|
||||
py_seq_ids, py_msg_ids, timer_interval, is_recreate, config_bytes,
|
||||
is_mock)
|
||||
logger.info("create DataReader succeed")
|
||||
|
||||
def read(self, timeout_millis):
|
||||
"""Read data from channel
|
||||
@@ -405,17 +316,16 @@ class DataReader:
|
||||
channel item
|
||||
"""
|
||||
if self.__queue.empty():
|
||||
messages = self.reader.read(timeout_millis)
|
||||
for message in messages:
|
||||
self.__queue.put(message)
|
||||
|
||||
msgs = self.reader.read(timeout_millis)
|
||||
for msg in msgs:
|
||||
msg_bytes, msg_id, timestamp, qid_bytes = msg
|
||||
data_msg = DataMessage(msg_bytes, timestamp,
|
||||
channel_bytes_to_str(qid_bytes), msg_id)
|
||||
self.__queue.put(data_msg)
|
||||
if self.__queue.empty():
|
||||
return None
|
||||
return self.__queue.get()
|
||||
|
||||
def get_channel_recover_info(self):
    """Wrap the recorded channel creation statuses in a ChannelRecoverInfo."""
    creation_status = self.__creation_status
    return ChannelRecoverInfo(creation_status)
|
||||
|
||||
def stop(self):
|
||||
logger.info("stopping Data Reader.")
|
||||
self.reader.stop()
|
||||
@@ -462,45 +372,3 @@ class ChannelInitException(Exception):
|
||||
class ChannelInterruptException(Exception):
    """Exception carrying an optional message.

    Args:
        msg: optional description; stored on ``self.msg`` and, when given,
            also forwarded to ``Exception`` so ``str(e)`` and ``e.args``
            expose it.
    """

    def __init__(self, msg=None):
        # Only forward a real message: Exception(None) would render as
        # the string "None" instead of an empty string.
        if msg is None:
            super().__init__()
        else:
            super().__init__(msg)
        self.msg = msg
|
||||
|
||||
|
||||
class ChannelRecoverInfo:
    """Holds a mapping of queue -> creation status and answers queries on it.

    Args:
        queue_creation_status_map: mapping to wrap; a fresh empty dict is
            used when omitted or ``None``.
    """

    def __init__(self, queue_creation_status_map=None):
        self.__queue_creation_status_map = (
            {} if queue_creation_status_map is None
            else queue_creation_status_map)

    def get_creation_status(self):
        """Return the wrapped mapping itself (not a copy)."""
        return self.__queue_creation_status_map

    def get_data_lost_queues(self):
        """Return the set of queues whose status is ``DataLost``."""
        return {
            queue
            for queue, status in self.__queue_creation_status_map.items()
            if status == ChannelCreationStatus.DataLost
        }

    def __str__(self):
        lost = self.get_data_lost_queues()
        return "QueueRecoverInfo [dataLostQueues=%s]" % (lost,)
|
||||
|
||||
|
||||
class ChannelCreationStatus(Enum):
    """Integer status codes recorded per channel.

    ``DataLost`` is the member that
    ``ChannelRecoverInfo.get_data_lost_queues`` filters on.
    """

    FreshStarted = 0
    PullOk = 1
    Timeout = 2
    DataLost = 3
|
||||
|
||||
|
||||
def channel_id_bytes_to_str(id_bytes):
    """
    Args:
        id_bytes: bytes representation of channel id; a ``str`` is
            returned unchanged

    Returns:
        string representation of channel id (hex digits for bytes input)

    Raises:
        TypeError: if ``id_bytes`` is neither ``str`` nor ``bytes``
    """
    if isinstance(id_bytes, str):
        return id_bytes
    if isinstance(id_bytes, bytes):
        return id_bytes.hex()
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    raise TypeError(
        "id_bytes must be str or bytes, got {}".format(
            type(id_bytes).__name__))
|
||||
|
||||
@@ -1,23 +1,12 @@
|
||||
import enum
|
||||
import logging.config
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
import ray
|
||||
import ray.streaming.runtime.processor as processor
|
||||
from ray.actor import ActorHandle
|
||||
from ray.streaming.generated import remote_call_pb2
|
||||
from ray.streaming.runtime.command import WorkerRollbackRequest
|
||||
from ray.streaming.runtime.failover import Barrier
|
||||
from ray.streaming.runtime.graph import ExecutionVertexContext, ExecutionVertex
|
||||
from ray.streaming.runtime.remote_call import CallResult, RemoteCallMst
|
||||
from ray.streaming.runtime.context_backend import ContextBackendFactory
|
||||
from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask
|
||||
from ray.streaming.runtime.transfer import channel_bytes_to_str
|
||||
from ray.streaming.config import Config
|
||||
import ray.streaming._streaming as _streaming
|
||||
import ray.streaming.generated.remote_call_pb2 as remote_call_pb
|
||||
import ray.streaming.runtime.processor as processor
|
||||
from ray.streaming.config import Config
|
||||
from ray.streaming.runtime.graph import ExecutionVertexContext
|
||||
from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -29,179 +18,74 @@ _NOT_READY_FLAG_ = b" " * 4
|
||||
class JobWorker(object):
|
||||
"""A streaming job worker is used to execute user-defined function and
|
||||
interact with `JobMaster`"""
|
||||
master_actor: Optional[ActorHandle]
|
||||
worker_context: Optional[remote_call_pb2.PythonJobWorkerContext]
|
||||
execution_vertex_context: Optional[ExecutionVertexContext]
|
||||
__need_rollback: bool
|
||||
|
||||
def __init__(self, execution_vertex_pb_bytes):
|
||||
logger.info("Creating job worker, pid={}".format(os.getpid()))
|
||||
execution_vertex_pb = remote_call_pb2\
|
||||
.ExecutionVertexContext.ExecutionVertex()
|
||||
execution_vertex_pb.ParseFromString(execution_vertex_pb_bytes)
|
||||
self.execution_vertex = ExecutionVertex(execution_vertex_pb)
|
||||
self.config = self.execution_vertex.config
|
||||
def __init__(self):
|
||||
self.worker_context = None
|
||||
self.execution_vertex_context = None
|
||||
self.config = None
|
||||
self.task_id = None
|
||||
self.task = None
|
||||
self.stream_processor = None
|
||||
self.master_actor = None
|
||||
self.context_backend = ContextBackendFactory.get_context_backend(
|
||||
self.config)
|
||||
self.initial_state_lock = threading.Lock()
|
||||
self.__rollback_cnt: int = 0
|
||||
self.__is_recreate: bool = False
|
||||
self.__state = WorkerState()
|
||||
self.__need_rollback = True
|
||||
self.reader_client = None
|
||||
self.writer_client = None
|
||||
try:
|
||||
# load checkpoint
|
||||
was_reconstructed = ray.get_runtime_context(
|
||||
).was_current_actor_reconstructed
|
||||
|
||||
logger.info(
|
||||
"Worker was reconstructed: {}".format(was_reconstructed))
|
||||
if was_reconstructed:
|
||||
job_worker_context_key = self.__get_job_worker_context_key()
|
||||
logger.info("Worker get checkpoint state by key: {}.".format(
|
||||
job_worker_context_key))
|
||||
context_bytes = self.context_backend.get(
|
||||
job_worker_context_key)
|
||||
if context_bytes is not None and context_bytes.__len__() > 0:
|
||||
self.init(context_bytes)
|
||||
self.request_rollback(
|
||||
"Python worker recover from checkpoint.")
|
||||
else:
|
||||
logger.error(
|
||||
"Error! Worker get checkpoint state by key {}"
|
||||
" returns None, please check your state backend"
|
||||
", only reliable state backend supports fail-over."
|
||||
.format(job_worker_context_key))
|
||||
except Exception:
|
||||
logger.exception("Error in __init__ of JobWorker")
|
||||
logger.info("Creating job worker succeeded. worker config {}".format(
|
||||
self.config))
|
||||
logger.info("Creating job worker succeeded.")
|
||||
|
||||
def init(self, worker_context_bytes):
|
||||
logger.info("Start to init job worker")
|
||||
try:
|
||||
# deserialize context
|
||||
worker_context = remote_call_pb2.PythonJobWorkerContext()
|
||||
worker_context.ParseFromString(worker_context_bytes)
|
||||
self.worker_context = worker_context
|
||||
self.master_actor = ActorHandle._deserialization_helper(
|
||||
worker_context.master_actor)
|
||||
worker_context = remote_call_pb.PythonJobWorkerContext()
|
||||
worker_context.ParseFromString(worker_context_bytes)
|
||||
self.worker_context = worker_context
|
||||
|
||||
# build vertex context from pb
|
||||
self.execution_vertex_context = ExecutionVertexContext(
|
||||
worker_context.execution_vertex_context)
|
||||
self.execution_vertex = self\
|
||||
.execution_vertex_context.execution_vertex
|
||||
# build vertex context from pb
|
||||
self.execution_vertex_context = ExecutionVertexContext(
|
||||
worker_context.execution_vertex_context)
|
||||
|
||||
# save context
|
||||
job_worker_context_key = self.__get_job_worker_context_key()
|
||||
self.context_backend.put(job_worker_context_key,
|
||||
worker_context_bytes)
|
||||
# use vertex id as task id
|
||||
self.task_id = self.execution_vertex_context.get_task_id()
|
||||
|
||||
# use vertex id as task id
|
||||
self.task_id = self.execution_vertex_context.get_task_id()
|
||||
# build and get processor from operator
|
||||
operator = self.execution_vertex_context.stream_operator
|
||||
self.stream_processor = processor.build_processor(operator)
|
||||
logger.info("Initializing job worker, exe_vertex_name={},"
|
||||
"task_id: {}, operator: {}, pid={}".format(
|
||||
self.execution_vertex_context.exe_vertex_name,
|
||||
self.task_id, self.stream_processor, os.getpid()))
|
||||
# build and get processor from operator
|
||||
operator = self.execution_vertex_context.stream_operator
|
||||
self.stream_processor = processor.build_processor(operator)
|
||||
logger.info(
|
||||
"Initializing job worker, task_id: {}, operator: {}.".format(
|
||||
self.task_id, self.stream_processor))
|
||||
|
||||
# get config from vertex
|
||||
self.config = self.execution_vertex_context.config
|
||||
# get config from vertex
|
||||
self.config = self.execution_vertex_context.config
|
||||
|
||||
if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL):
|
||||
self.reader_client = _streaming.ReaderClient()
|
||||
self.writer_client = _streaming.WriterClient()
|
||||
if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL):
|
||||
self.reader_client = _streaming.ReaderClient()
|
||||
self.writer_client = _streaming.WriterClient()
|
||||
|
||||
logger.info("Job worker init succeeded.")
|
||||
except Exception:
|
||||
logger.exception("Error when init job worker.")
|
||||
return False
|
||||
self.task = self.create_stream_task()
|
||||
|
||||
logger.info("Job worker init succeeded.")
|
||||
return True
|
||||
|
||||
def create_stream_task(self, checkpoint_id):
|
||||
def start(self):
|
||||
self.task.start()
|
||||
logger.info("Job worker start succeeded.")
|
||||
|
||||
def create_stream_task(self):
|
||||
if isinstance(self.stream_processor, processor.SourceProcessor):
|
||||
return SourceStreamTask(self.task_id, self.stream_processor, self,
|
||||
checkpoint_id)
|
||||
return SourceStreamTask(self.task_id, self.stream_processor, self)
|
||||
elif isinstance(self.stream_processor, processor.OneInputProcessor):
|
||||
return OneInputStreamTask(self.task_id, self.stream_processor,
|
||||
self, checkpoint_id)
|
||||
self)
|
||||
else:
|
||||
raise Exception("Unsupported processor type: " +
|
||||
str(type(self.stream_processor)))
|
||||
type(self.stream_processor))
|
||||
|
||||
def rollback(self, checkpoint_id_bytes):
    """Restart this worker's stream task from the given checkpoint.

    The request is skipped when the current task's thread is alive, is
    still in its initial state and already sits on the requested
    checkpoint. Otherwise the current task (if any) is cancelled, a new
    one is created for the checkpoint and its ``recover`` is invoked.

    Args:
        checkpoint_id_bytes: serialized ``remote_call_pb2.CheckpointId``.

    Returns:
        The output of ``self.__gen_call_result`` for a skipped, successful
        or failed ``CallResult``.
    """
    checkpoint_id_pb = remote_call_pb2.CheckpointId()
    checkpoint_id_pb.ParseFromString(checkpoint_id_bytes)
    checkpoint_id = checkpoint_id_pb.checkpoint_id

    logger.info("Start rollback, checkpoint_id={}".format(checkpoint_id))

    # Every rollback after the first one is treated as a re-creation.
    self.__rollback_cnt += 1
    if self.__rollback_cnt > 1:
        self.__is_recreate = True

    # skip useless rollback
    with self.initial_state_lock:
        if (self.task is not None and self.task.thread.is_alive()
                and checkpoint_id == self.task.last_checkpoint_id
                and self.task.is_initial_state):
            skip_msg = "Task is already in initial state, skip this rollback."
            logger.info(skip_msg)
            return self.__gen_call_result(CallResult.skipped(skip_msg))

    # restart task
    try:
        if self.task is not None:
            # make sure the runner is closed
            self.task.cancel_task()
            # Reset to None rather than `del`: if creating the new task
            # fails below, the attribute must still exist so that the
            # `self.task is not None` checks in other methods keep working.
            self.task = None

        self.task = self.create_stream_task(checkpoint_id)

        q_recover_info = self.task.recover(self.__is_recreate)

        self.__state.set_type(StateType.RUNNING)
        self.__need_rollback = False

        logger.info(
            "Rollback success, checkpoint is {}, qRecoverInfo is {}.".
            format(checkpoint_id, q_recover_info))

        return self.__gen_call_result(CallResult.success(q_recover_info))
    except Exception:
        # Any failure is reported to the caller as a failed CallResult.
        logger.exception("Rollback has exception.")
        return self.__gen_call_result(CallResult.fail())
|
||||
|
||||
def on_reader_message(self, *buffers):
|
||||
def on_reader_message(self, buffer: bytes):
|
||||
"""Called by upstream queue writer to send data message to downstream
|
||||
queue reader.
|
||||
"""
|
||||
if self.reader_client is None:
|
||||
logger.info("reader_client is None, skip writer transfer")
|
||||
return
|
||||
self.reader_client.on_reader_message(*buffers)
|
||||
self.reader_client.on_reader_message(buffer)
|
||||
|
||||
def on_reader_message_sync(self, buffer: bytes):
|
||||
"""Called by upstream queue writer to send
|
||||
control message to downstream downstream queue reader.
|
||||
"""Called by upstream queue writer to send control message to downstream
|
||||
downstream queue reader.
|
||||
"""
|
||||
if self.reader_client is None:
|
||||
logger.info("task is None, skip reader transfer")
|
||||
return _NOT_READY_FLAG_
|
||||
result = self.reader_client.on_reader_message_sync(buffer)
|
||||
return result.to_pybytes()
|
||||
@@ -210,9 +94,6 @@ class JobWorker(object):
|
||||
"""Called by downstream queue reader to send notify message to
|
||||
upstream queue writer.
|
||||
"""
|
||||
if self.writer_client is None:
|
||||
logger.info("writer_client is None, skip writer transfer")
|
||||
return
|
||||
self.writer_client.on_writer_message(buffer)
|
||||
|
||||
def on_writer_message_sync(self, buffer: bytes):
|
||||
@@ -223,164 +104,3 @@ class JobWorker(object):
|
||||
return _NOT_READY_FLAG_
|
||||
result = self.writer_client.on_writer_message_sync(buffer)
|
||||
return result.to_pybytes()
|
||||
|
||||
def shutdown_without_reconstruction(self):
    """Log the shutdown and terminate the actor via ``ray.actor.exit_actor``."""
    shutdown_msg = "Python worker shutdown without reconstruction."
    logger.info(shutdown_msg)
    ray.actor.exit_actor()
|
||||
|
||||
def notify_checkpoint_timeout(self, checkpoint_id_bytes):
    """Checkpoint-timeout notification; this implementation ignores it."""
|
||||
|
||||
def commit(self, barrier_bytes):
    """Forward a checkpoint trigger to the current task, if there is one.

    Args:
        barrier_bytes: serialized ``remote_call_pb2.Barrier``.

    Returns:
        A serialized ``remote_call_pb2.BoolResult`` whose ``boolRes`` is
        always ``True``; the task's own return value is not consulted.
    """
    parsed = remote_call_pb2.Barrier()
    parsed.ParseFromString(barrier_bytes)
    barrier = Barrier(parsed.id)
    logger.info("Receive trigger, barrier is {}.".format(barrier))

    if self.task is not None:
        self.task.commit_trigger(barrier)

    reply = remote_call_pb2.BoolResult()
    reply.boolRes = True
    return reply.SerializeToString()
|
||||
|
||||
def clear_expired_cp(self, state_checkpoint_id_bytes,
                     queue_checkpoint_id_bytes):
    """Clear expired checkpoint state and expired queue messages.

    State is cleared only when the parsed state checkpoint id is positive;
    queue messages are always asked to be cleared.

    Args:
        state_checkpoint_id_bytes: serialized CheckpointId for state.
        queue_checkpoint_id_bytes: serialized CheckpointId for queues.

    Returns:
        A serialized ``remote_call_pb2.BoolResult`` combining both results.
    """
    state_checkpoint_id = self.__parse_to_checkpoint_id(
        state_checkpoint_id_bytes)
    queue_checkpoint_id = self.__parse_to_checkpoint_id(
        queue_checkpoint_id_bytes)
    vertex_name = self.execution_vertex_context.exe_vertex_name
    logger.info("Start to clear expired checkpoint, checkpoint_id={},"
                "queue_checkpoint_id={}, exe_vertex_name={}.".format(
                    state_checkpoint_id, queue_checkpoint_id, vertex_name))

    ret = remote_call_pb2.BoolResult()
    if state_checkpoint_id > 0:
        ret.boolRes = self.__clear_expired_cp_state(state_checkpoint_id)
    else:
        ret.boolRes = True
    # `&=` (not `and`): the queue clean-up runs even if the state one failed.
    ret.boolRes &= self.__clear_expired_queue_msg(queue_checkpoint_id)
    logger.info(
        "Clear expired checkpoint done, result={}, checkpoint_id={},"
        "queue_checkpoint_id={}, exe_vertex_name={}.".format(
            ret.boolRes, state_checkpoint_id, queue_checkpoint_id,
            self.execution_vertex_context.exe_vertex_name))
    return ret.SerializeToString()
|
||||
|
||||
def __clear_expired_cp_state(self, checkpoint_id):
    """Ask the task to drop expired checkpoint state.

    Returns:
        ``False`` when the worker is flagged as needing a rollback (nothing
        is cleared), ``True`` otherwise.
    """
    if not self.__need_rollback:
        logger.info("Clear expired checkpoint state, cp id is {}.".format(
            checkpoint_id))
        if self.task is not None:
            self.task.clear_expired_cp_state(checkpoint_id)
        return True

    logger.warning("Need rollback, skip clear_expired_cp_state"
                   ", checkpoint id: {}".format(checkpoint_id))
    return False
|
||||
|
||||
def __clear_expired_queue_msg(self, checkpoint_id):
    """Ask the task to drop expired queue messages.

    Returns:
        ``False`` when the worker is flagged as needing a rollback (nothing
        is cleared), ``True`` otherwise.
    """
    if not self.__need_rollback:
        logger.info("Clear expired queue msg, checkpoint_id is {}.".format(
            checkpoint_id))
        if self.task is not None:
            self.task.clear_expired_queue_msg(checkpoint_id)
        return True

    logger.warning("Need rollback, skip clear_expired_queue_msg"
                   ", checkpoint id: {}".format(checkpoint_id))
    return False
|
||||
|
||||
def __parse_to_checkpoint_id(self, checkpoint_id_bytes):
    """Decode a serialized ``CheckpointId`` and return its ``checkpoint_id``."""
    parsed = remote_call_pb2.CheckpointId()
    parsed.ParseFromString(checkpoint_id_bytes)
    return parsed.checkpoint_id
|
||||
|
||||
def check_if_need_rollback(self):
    """Report the need-rollback flag as a serialized ``BoolResult``."""
    reply = remote_call_pb2.BoolResult()
    reply.boolRes = self.__need_rollback
    return reply.SerializeToString()
|
||||
|
||||
def request_rollback(self, exception_msg="Python exception."):
    """Ask the master to roll this worker back, retrying on failure.

    Marks the worker as needing rollback and re-creation, then sends up
    to ``Config.REQUEST_ROLLBACK_RETRY_TIMES`` requests, sleeping one
    second after each unsuccessful one. If no request succeeds the worker
    shuts itself down. Finally the state is set to ``WAIT_ROLLBACK``.

    Args:
        exception_msg: text embedded in the rollback request.
    """
    logger.info("Request rollback.")

    self.__need_rollback = True
    self.__is_recreate = True

    request_ret = False
    max_attempts = Config.REQUEST_ROLLBACK_RETRY_TIMES
    for attempt in range(max_attempts):
        logger.info("request rollback {} time".format(attempt))
        try:
            request = WorkerRollbackRequest(
                self.execution_vertex_context.actor_id.binary(),
                "Exception msg=%s, retry time=%d." % (exception_msg,
                                                      attempt))
            request_ret = RemoteCallMst.request_job_worker_rollback(
                self.master_actor, request)
        except Exception:
            # An error counts as an unsuccessful attempt, not a crash.
            logger.exception("Unexpected error when rollback")
        logger.info("request rollback {} time, ret={}".format(
            attempt, request_ret))
        if request_ret:
            break
        logger.warning(
            "Request rollback return false"
            ", maybe it's invalid request, try to sleep 1s.")
        time.sleep(1)
    else:
        # Reached only when no attempt succeeded (the loop never broke).
        logger.warning("Request failed after retry {} times,"
                       "now worker shutdown without reconstruction."
                       .format(max_attempts))
        self.shutdown_without_reconstruction()

    self.__state.set_type(StateType.WAIT_ROLLBACK)
|
||||
|
||||
def __gen_call_result(self, call_result):
    """Convert a ``CallResult`` into serialized ``remote_call_pb2.CallResult``.

    ``result_msg`` is copied only when it is not ``None``. When
    ``result_obj`` is present, its creation statuses are copied into the
    protobuf keyed by the string form of each queue id.
    """
    pb = remote_call_pb2.CallResult()
    pb.success = call_result.success
    pb.result_code = call_result.result_code.value

    message = call_result.result_msg
    if message is not None:
        pb.result_msg = message

    q_recover_info = call_result.result_obj
    if q_recover_info is not None:
        statuses = q_recover_info.get_creation_status()
        for queue, status in statuses.items():
            key = channel_bytes_to_str(queue)
            pb.result_obj.creation_status[key] = status.value

    return pb.SerializeToString()
|
||||
|
||||
def _gen_unique_key(self, key_prefix):
    """Build ``<key_prefix><job name>_<execution vertex id>``.

    The job name comes from ``self.config`` under
    ``Config.STREAMING_JOB_NAME``; the vertex id from
    ``self.execution_vertex``.
    """
    job_name = str(self.config.get(Config.STREAMING_JOB_NAME))
    vertex_id = str(self.execution_vertex.execution_vertex_id)
    return key_prefix + job_name + "_" + vertex_id
|
||||
|
||||
def __get_job_worker_context_key(self) -> str:
    """Return the unique key built from ``Config.JOB_WORKER_CONTEXT_KEY``."""
    prefix = Config.JOB_WORKER_CONTEXT_KEY
    return self._gen_unique_key(prefix)
|
||||
|
||||
|
||||
class WorkerState:
    """
    Mutable holder for a worker's current state type; starts at
    ``StateType.INIT``.
    """

    def __init__(self):
        self.__type = StateType.INIT

    def set_type(self, type):
        """Replace the stored state type with ``type``."""
        self.__type = type

    def get_type(self):
        """Return the most recently stored state type."""
        return self.__type
|
||||
|
||||
|
||||
class StateType(enum.Enum):
    """
    Worker state values: initial, running, and waiting for a rollback.
    """

    INIT = 1
    RUNNING = 2
    WAIT_ROLLBACK = 3
|
||||
|
||||
Reference in New Issue
Block a user