mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 01:55:25 +08:00
This reverts commit 1b1466748f.
This commit is contained in:
@@ -3,11 +3,10 @@ import importlib
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ray import streaming
|
||||
from ray.streaming import function
|
||||
from ray.streaming import message
|
||||
from ray.streaming.collector import Collector
|
||||
from ray.streaming.collector import CollectionCollector
|
||||
from ray.streaming.function import SourceFunction
|
||||
from ray.streaming.runtime import gateway_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -41,14 +40,6 @@ class Operator(ABC):
|
||||
def operator_type(self) -> OperatorType:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_checkpoint(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_checkpoint(self, checkpoint_obj):
|
||||
pass
|
||||
|
||||
|
||||
class OneInputOperator(Operator, ABC):
|
||||
"""Interface for stream operators with one input."""
|
||||
@@ -99,20 +90,8 @@ class StreamOperator(Operator, ABC):
|
||||
for collector in self.collectors:
|
||||
collector.collect(record)
|
||||
|
||||
def save_checkpoint(self):
|
||||
self.func.save_checkpoint()
|
||||
|
||||
def load_checkpoint(self, checkpoint_obj):
|
||||
self.func.load_checkpoint(checkpoint_obj)
|
||||
|
||||
|
||||
class SourceOperator(Operator, ABC):
|
||||
@abstractmethod
|
||||
def fetch(self):
|
||||
pass
|
||||
|
||||
|
||||
class SourceOperatorImpl(SourceOperator, StreamOperator):
|
||||
class SourceOperator(StreamOperator):
|
||||
"""
|
||||
Operator to run a :class:`function.SourceFunction`
|
||||
"""
|
||||
@@ -125,19 +104,19 @@ class SourceOperatorImpl(SourceOperator, StreamOperator):
|
||||
for collector in self.collectors:
|
||||
collector.collect(message.Record(value))
|
||||
|
||||
def __init__(self, func: SourceFunction):
|
||||
def __init__(self, func):
|
||||
assert isinstance(func, function.SourceFunction)
|
||||
super().__init__(func)
|
||||
self.source_context = None
|
||||
|
||||
def open(self, collectors, runtime_context):
|
||||
super().open(collectors, runtime_context)
|
||||
self.source_context = SourceOperatorImpl.SourceContextImpl(collectors)
|
||||
self.source_context = SourceOperator.SourceContextImpl(collectors)
|
||||
self.func.init(runtime_context.get_parallelism(),
|
||||
runtime_context.get_task_index())
|
||||
|
||||
def fetch(self):
|
||||
self.func.fetch(self.source_context)
|
||||
def run(self):
|
||||
self.func.run(self.source_context)
|
||||
|
||||
def operator_type(self):
|
||||
return OperatorType.SOURCE
|
||||
@@ -168,7 +147,8 @@ class FlatMapOperator(StreamOperator, OneInputOperator):
|
||||
|
||||
def open(self, collectors, runtime_context):
|
||||
super().open(collectors, runtime_context)
|
||||
self.collection_collector = CollectionCollector(collectors)
|
||||
self.collection_collector = streaming.collector.CollectionCollector(
|
||||
collectors)
|
||||
|
||||
def process_element(self, record):
|
||||
self.func.flat_map(record.value, self.collection_collector)
|
||||
@@ -306,12 +286,12 @@ class ChainedOperator(StreamOperator, ABC):
|
||||
raise Exception("Current operator type is not supported")
|
||||
|
||||
|
||||
class ChainedSourceOperator(SourceOperator, ChainedOperator):
|
||||
class ChainedSourceOperator(ChainedOperator):
|
||||
def __init__(self, operators, configs):
|
||||
super().__init__(operators, configs)
|
||||
|
||||
def fetch(self):
|
||||
self.operators[0].fetch()
|
||||
def run(self):
|
||||
self.operators[0].run()
|
||||
|
||||
|
||||
class ChainedOneInputOperator(ChainedOperator):
|
||||
@@ -370,7 +350,7 @@ def load_operator(descriptor_operator_bytes: bytes):
|
||||
|
||||
|
||||
_function_to_operator = {
|
||||
function.SourceFunction: SourceOperatorImpl,
|
||||
function.SourceFunction: SourceOperator,
|
||||
function.MapFunction: MapOperator,
|
||||
function.FlatMapFunction: FlatMapOperator,
|
||||
function.FilterFunction: FilterOperator,
|
||||
|
||||
Reference in New Issue
Block a user