[Streaming] Streaming Python API (#6755)

This commit is contained in:
chaokunyang
2020-02-25 10:33:33 +08:00
committed by GitHub
parent 2c1f4fd82c
commit 8b6784de06
71 changed files with 2701 additions and 1928 deletions
-16
View File
@@ -1,16 +0,0 @@
Streaming Library
=================
Dependencies:
Install NetworkX: ``pip install networkx``
Examples:
- simple.py: A simple example with stateless operators and different parallelism per stage.
Run ``python simple.py --input-file toy.txt``
- wordcount.py: A streaming wordcount example with a stateful operator (rolling sum).
Run ``python wordcount.py --titles-file articles.txt``
+3
View File
@@ -1,3 +1,6 @@
# flake8: noqa
# Ray should be imported before streaming
import ray
from ray.streaming.context import StreamingContext
__all__ = ['StreamingContext']
+49
View File
@@ -0,0 +1,49 @@
import logging
import pickle
import typing
from abc import ABC, abstractmethod
from ray.streaming import message
from ray.streaming import partition
from ray.streaming.runtime.transfer import ChannelID, DataWriter
logger = logging.getLogger(__name__)
class Collector(ABC):
    """Abstract sink for records produced by an upstream operator.

    A concrete collector receives each record through ``collect`` and is
    responsible for handing it on to whatever sits downstream.
    """

    @abstractmethod
    def collect(self, record):
        """Accept one record; concrete subclasses must override this."""
class CollectionCollector(Collector):
    """Fans a value out to every collector in a list."""

    def __init__(self, collector_list):
        """
        Args:
            collector_list: the downstream collectors to forward to.
        """
        self._collector_list = collector_list

    def collect(self, value):
        """Wrap ``value`` in a ``message.Record`` for each downstream
        collector and pass it on, in list order."""
        for downstream in self._collector_list:
            downstream.collect(message.Record(value))
class OutputCollector(Collector):
    """Writes pickled records to the channels chosen by a partition
    function."""

    def __init__(self, channel_ids: typing.List[str], writer: DataWriter,
                 partition_func: partition.Partition):
        """
        Args:
            channel_ids: string ids of the output channels; each one is
                converted to a ``ChannelID``.
            writer: the writer used to send serialized records.
            partition_func: decides which channel indexes get a record.
        """
        self._channel_ids = [ChannelID(id_str) for id_str in channel_ids]
        self._writer = writer
        self._partition_func = partition_func
        logger.info(
            "Create OutputCollector, channel_ids {}, partition_func {}".format(
                channel_ids, partition_func))

    def collect(self, record):
        """Serialize ``record`` once and write it to every selected
        channel."""
        channel_count = len(self._channel_ids)
        targets = self._partition_func.partition(record, channel_count)
        payload = pickle.dumps(record)
        for target in targets:
            self._writer.write(self._channel_ids[target], payload)
-279
View File
@@ -1,279 +0,0 @@
import hashlib
import logging
import pickle
import sys
import time
import ray
import ray.streaming.runtime.transfer as transfer
from ray.streaming.config import Config
from ray.streaming.operator import PStrategy
from ray.streaming.runtime.transfer import ChannelID
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Forward and broadcast stream partitioning strategies
forward_broadcast_strategies = [PStrategy.Forward, PStrategy.Broadcast]
# Used to choose output channel in case of hash-based shuffling
def _hash(value):
    """Map ``value`` to an integer.

    Integers are returned unchanged. Anything else is hashed with SHA-1:
    values that expose ``encode`` (e.g. ``str``) are UTF-8 encoded first,
    values that do not (e.g. ``bytes``) are fed to the digest as they are.
    """
    if isinstance(value, int):
        return value
    try:
        payload = value.encode("utf-8")
    except AttributeError:
        payload = value
    return int(hashlib.sha1(payload).hexdigest(), 16)
class DataChannel:
    """A data channel for actor-to-actor communication.

    Attributes:
        src_operator_id: The id of the source operator of the channel.
        src_instance_index: The index of the source instance.
        dst_operator_id: The id of the destination operator of the channel.
        dst_instance_index: The index of the destination instance.
        str_qid: The channel id as a string.
        qid: ``ChannelID`` built from ``str_qid``.
    """

    def __init__(self, src_operator_id, src_instance_index, dst_operator_id,
                 dst_instance_index, str_qid):
        # Source endpoint
        self.src_operator_id = src_operator_id
        self.src_instance_index = src_instance_index
        # Destination endpoint
        self.dst_operator_id = dst_operator_id
        self.dst_instance_index = dst_instance_index
        # Channel identity, kept both as text and as a ChannelID
        self.str_qid = str_qid
        self.qid = ChannelID(str_qid)

    def __repr__(self):
        template = "(src({},{}),dst({},{}), qid({}))"
        fields = (self.src_operator_id, self.src_instance_index,
                  self.dst_operator_id, self.dst_instance_index,
                  self.str_qid)
        return template.format(*fields)
_CLOSE_FLAG = b" "


# Pulls and merges data from multiple input channels
class DataInput:
    """An input gate of an operator instance.

    All input channels are read through a single ``DataReader`` that is
    created in ``init``.

    Attributes:
        env: The environment this gate belongs to.
        reader: The ``DataReader``; ``None`` until ``init`` is called.
        input_channels (list): The list of input channels.
        channel_index (int): Initialised to 0; not used by the methods of
            this class.
        max_index (int): The number of input channels.
        closed (dict): Maps the id of each channel that has sent the close
            flag to ``True``.
    """

    def __init__(self, env, channels):
        assert len(channels) > 0
        self.env = env
        self.reader = None  # created in `init` method
        self.input_channels = channels
        self.channel_index = 0
        self.max_index = len(channels)
        # Tracks the channels that have been closed. qid: close status
        self.closed = {}

    def init(self):
        """Create the ``DataReader`` over all input channels."""
        channels = [c.str_qid for c in self.input_channels]
        input_actors = [
            self.env.execution_graph.get_actor(c.src_operator_id,
                                               c.src_instance_index)
            for c in self.input_channels
        ]
        logger.info("DataInput input_actors %s", input_actors)
        conf = {
            Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context()
            .current_driver_id,
            Config.CHANNEL_TYPE: self.env.config.channel_type
        }
        self.reader = transfer.DataReader(channels, input_actors, conf)

    def pull(self):
        """Block until the next record is available and return it.

        Close flags are consumed silently; each one marks its channel as
        closed.

        Returns:
            The unpickled record, or ``None`` once every input channel has
            sent the close flag.
        """
        # A loop (rather than recursion per close flag) keeps the stack flat
        # no matter how many channels close back to back.
        while True:
            item = self.reader.read(100)
            while item is None:
                time.sleep(0.001)
                item = self.reader.read(100)
            msg_data = item.body()
            if msg_data != _CLOSE_FLAG:
                # NOTE: unpickling is only safe for data written by trusted
                # peers.
                return pickle.loads(msg_data)
            self.closed[item.channel_id] = True
            if len(self.closed) == len(self.input_channels):
                return None

    def close(self):
        """Stop the underlying reader."""
        self.reader.stop()
# Selects output channel(s) and pushes data
class DataOutput:
    """An output gate of an operator instance.

    The output gate pushes records to output channels according to the
    user-defined partitioning scheme.

    Attributes:
        partitioning_schemes (dict): A mapping from destination operator ids
            to partitioning schemes (see: PScheme in operator.py).
        forward_channels (list): Channels whose strategy is forward or
            broadcast; every record is sent to all of them.
        round_robin_channels (list(list)): Round-robin output channels
            grouped by destination operator.
        round_robin_indexes (list): For each round-robin group, the index of
            the channel used by the previous push (-1 before the first).
        shuffle_channels (list(list)): A list of output channels to shuffle
            records grouped by destination operator.
        shuffle_key_channels (list(list)): A list of output channels to
            shuffle records by a key grouped by destination operator.
        shuffle_exists (bool): A flag indicating that there exists at least
            one shuffle_channel.
        shuffle_key_exists (bool): A flag indicating that there exists at
            least one shuffle_key_channel.
    """

    def __init__(self, env, channels, partitioning_schemes):
        assert len(channels) > 0
        self.env = env
        self.writer = None  # created in `init` method
        self.channels = channels
        self.key_selector = None
        self.partitioning_schemes = partitioning_schemes

        def count_slots(strategy):
            # Number of destination operators using the given strategy
            return sum(1 for scheme in self.partitioning_schemes.values()
                       if scheme.strategy == strategy)

        # Prepare output -- collect channels by type
        self.forward_channels = []  # Forward and broadcast channels
        # Each group needs its OWN list: `[[]] * n` would repeat a single
        # list object, making every destination share the same channels.
        slots = count_slots(PStrategy.RoundRobin)
        self.round_robin_channels = [[] for _ in range(slots)]
        self.round_robin_indexes = [-1] * slots
        slots = count_slots(PStrategy.Shuffle)
        # Flag used to avoid hashing when there is no shuffling
        self.shuffle_exists = slots > 0
        self.shuffle_channels = [[] for _ in range(slots)]
        slots = count_slots(PStrategy.ShuffleByKey)
        # Flag used to avoid hashing when there is no shuffling by key
        self.shuffle_key_exists = slots > 0
        self.shuffle_key_channels = [[] for _ in range(slots)]

        # Destination operator id -> group position, one map per strategy
        shuffle_destinations = {}
        shuffle_by_key_destinations = {}
        round_robin_destinations = {}
        for channel in channels:
            p_scheme = self.partitioning_schemes[channel.dst_operator_id]
            strategy = p_scheme.strategy
            if strategy in forward_broadcast_strategies:
                self.forward_channels.append(channel)
            elif strategy == PStrategy.Shuffle:
                self._add_to_group(self.shuffle_channels,
                                   shuffle_destinations, channel)
            elif strategy == PStrategy.ShuffleByKey:
                self._add_to_group(self.shuffle_key_channels,
                                   shuffle_by_key_destinations, channel)
            elif strategy == PStrategy.RoundRobin:
                self._add_to_group(self.round_robin_channels,
                                   round_robin_destinations, channel)
            else:  # TODO (john): Add support for other strategies
                sys.exit("Unrecognized or unsupported partitioning strategy.")
        # A KeyedDataStream can only be shuffled by key
        assert not (self.shuffle_exists and self.shuffle_key_exists)

    @staticmethod
    def _add_to_group(groups, destinations, channel):
        """Append ``channel`` to the group of its destination operator.

        A destination seen for the first time is assigned the next free
        position (the number of destinations registered so far).
        """
        pos = destinations.setdefault(channel.dst_operator_id,
                                      len(destinations))
        groups[pos].append(channel)

    def init(self):
        """init DataOutput which creates DataWriter"""
        channel_ids = [c.str_qid for c in self.channels]
        to_actors = [
            self.env.execution_graph.get_actor(c.dst_operator_id,
                                               c.dst_instance_index)
            for c in self.channels
        ]
        logger.info("DataOutput output_actors %s", to_actors)
        conf = {
            Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context()
            .current_driver_id,
            Config.CHANNEL_TYPE: self.env.config.channel_type
        }
        self.writer = transfer.DataWriter(channel_ids, to_actors, conf)

    def close(self):
        """Close the channel (True) by propagating _CLOSE_FLAG

        _CLOSE_FLAG is used as special type of record that is propagated from
        sources to sink to notify that the end of data in a stream.
        """
        for c in self.channels:
            self.writer.write(c.qid, _CLOSE_FLAG)
        # must ensure DataWriter send None flag to peer actor
        self.writer.stop()

    def push(self, record):
        """Select the target channels for ``record`` and write it to each.

        The record is pickled once and the same bytes are written to every
        selected channel.
        """
        target_channels = []
        # Forward / broadcast: every such channel receives the record
        for c in self.forward_channels:
            logger.debug("[writer] Push record '{}' to channel {}".format(
                record, c))
            target_channels.append(c)
        # Round robin: advance each group's cursor and pick that channel
        for index, channels in enumerate(self.round_robin_channels):
            self.round_robin_indexes[index] += 1
            if self.round_robin_indexes[index] == len(channels):
                self.round_robin_indexes[index] = 0  # Reset index
            c = channels[self.round_robin_indexes[index]]
            logger.debug("[writer] Push record '{}' to channel {}".format(
                record, c))
            target_channels.append(c)
        # Hash-based shuffling by key
        if self.shuffle_key_exists:
            key, _ = record
            h = _hash(key)
            for channels in self.shuffle_key_channels:
                num_instances = len(channels)  # Downstream instances
                c = channels[h % num_instances]
                logger.debug(
                    "[key_shuffle] Push record '{}' to channel {}".format(
                        record, c))
                target_channels.append(c)
        elif self.shuffle_exists:  # Hash-based shuffling per destination
            h = _hash(record)
            for channels in self.shuffle_channels:
                num_instances = len(channels)  # Downstream instances
                c = channels[h % num_instances]
                logger.debug("[shuffle] Push record '{}' to channel {}".format(
                    record, c))
                target_channels.append(c)
        else:  # TODO (john): Handle rescaling
            pass

        msg_data = pickle.dumps(record)
        for c in target_channels:
            # send data to channel
            self.writer.write(c.qid, msg_data)

    def push_all(self, records):
        """Push every record of ``records`` in order."""
        for record in records:
            self.push(record)
+2 -1
View File
@@ -13,7 +13,8 @@ class Config:
# return from StreamingReader.getBundle if only empty message read in this
# interval.
TIMER_INTERVAL_MS = "timer_interval_ms"
READ_TIMEOUT_MS = "read_timeout_ms"
DEFAULT_READ_TIMEOUT_MS = "10"
STREAMING_RING_BUFFER_CAPACITY = "streaming.ring_buffer_capacity"
# write an empty message if there is no data to be written in this
# interval.
+168
View File
@@ -0,0 +1,168 @@
from abc import ABC, abstractmethod
from ray.streaming.datastream import StreamSource
from ray.streaming.function import LocalFileSourceFunction
from ray.streaming.function import CollectionSourceFunction
from ray.streaming.function import SourceFunction
from ray.streaming.runtime.gateway_client import GatewayClient
class StreamingContext:
    """
    Main entry point for ray streaming functionality.

    A StreamingContext is also a wrapper of java
    `org.ray.streaming.api.context.StreamingContext`
    """

    class Builder:
        """Collects configuration options and builds a StreamingContext."""

        def __init__(self):
            self._options = {}

        def option(self, key=None, value=None, conf=None):
            """
            Sets a config option. Options set using this method are
            automatically propagated to :class:`StreamingContext`'s own
            configuration.

            Args:
                key: a key name string for configuration property
                value: a value for configuration property; it is stored as
                    ``str(value)``. Required whenever ``key`` is given.
                conf: multi key-value pairs as a dict; values are stored
                    as they are.

            Returns:
                self
            """
            if key is not None:
                # Only a missing value is an error: 0, False and "" are
                # legitimate option values and must not be rejected.
                assert value is not None, \
                    "value must be provided together with key"
                self._options[key] = str(value)
            if conf is not None:
                for k, v in conf.items():
                    self._options[k] = v
            return self

        def build(self):
            """
            Creates a StreamingContext based on the options set in this
            builder.
            """
            ctx = StreamingContext()
            ctx._gateway_client.with_config(self._options)
            return ctx

    def __init__(self):
        self.__gateway_client = GatewayClient()
        self._j_ctx = self._gateway_client.create_streaming_context()

    def source(self, source_func: SourceFunction):
        """Create an input data stream with a SourceFunction

        Args:
            source_func: the SourceFunction used to create the data stream

        Returns:
            The data stream constructed from the source_func
        """
        return StreamSource.build_source(self, source_func)

    def from_values(self, *values):
        """Creates a data stream from values

        Args:
            values: The elements to create the data stream from.

        Returns:
            The data stream representing the given values
        """
        return self.from_collection(values)

    def from_collection(self, values):
        """Creates a data stream from the given non-empty collection.

        Args:
            values: The collection of elements to create the data stream from.

        Returns:
            The data stream representing the given collection.
        """
        assert values, "values shouldn't be None or empty"
        func = CollectionSourceFunction(values)
        return self.source(func)

    def read_text_file(self, filename: str):
        """Reads the given file line-by-line and creates a data stream that
        contains a string with the contents of each such line."""
        func = LocalFileSourceFunction(filename)
        return self.source(func)

    def submit(self, job_name: str):
        """Submit job for execution.

        Args:
            job_name: name of the job

        Returns:
            Nothing at present; a JobSubmissionResult future is planned
            (see the TODO below).
        """
        self._gateway_client.execute(job_name)
        # TODO return a JobSubmissionResult future

    def execute(self, job_name: str):
        """Execute the job. This method will block until job finished.

        Not implemented yet: it currently always raises.

        Args:
            job_name: name of the job

        Raises:
            Exception: always, with the message "Unsupported".
        """
        # TODO support block to job finish
        # job_submit_result = self.submit(job_name)
        # job_submit_result.wait_finish()
        raise Exception("Unsupported")

    @property
    def _gateway_client(self):
        """The GatewayClient created in ``__init__``."""
        return self.__gateway_client
class RuntimeContext(ABC):
    """Read-only description of the parallel task a function runs in."""

    @abstractmethod
    def get_task_id(self):
        """
        Returns:
            Task id of the parallel task.
        """

    @abstractmethod
    def get_task_index(self):
        """
        Gets the index of this parallel subtask. Indexes start at 0 and end
        at parallelism-1, where parallelism is the value returned by
        `get_parallelism()`.

        Returns:
            The index of the parallel subtask.
        """

    @abstractmethod
    def get_parallelism(self):
        """
        Returns:
            The parallelism with which the parallel task runs.
        """
class RuntimeContextImpl(RuntimeContext):
    """RuntimeContext that simply stores the values it was given."""

    def __init__(self, task_id, task_index, parallelism):
        self.task_id, self.task_index, self.parallelism = (
            task_id, task_index, parallelism)

    def get_task_id(self):
        """Return the stored task id."""
        return self.task_id

    def get_task_index(self):
        """Return the stored task index."""
        return self.task_index

    def get_parallelism(self):
        """Return the stored parallelism."""
        return self.parallelism
+284
View File
@@ -0,0 +1,284 @@
from abc import ABC
from ray.streaming import function
from ray.streaming import partition
class Stream(ABC):
    """
    Common base of every stream type. A Stream stands for a sequence of
    elements of one type and can be turned into another Stream by applying a
    transformation.
    """

    def __init__(self, input_stream, j_stream, streaming_context=None):
        """
        Args:
            input_stream: the upstream Stream; required when
                ``streaming_context`` is not given.
            j_stream: the handle passed to the gateway client for calls on
                this stream.
            streaming_context: the owning context; inherited from
                ``input_stream`` when omitted.
        """
        if streaming_context is None:
            # Without an explicit context there must be a parent to take
            # it from.
            assert input_stream is not None
            streaming_context = input_stream.streaming_context
        self.streaming_context = streaming_context
        self.input_stream = input_stream
        self._j_stream = j_stream
        self.parallelism = 1

    def get_streaming_context(self):
        """
        Returns:
            the streaming context this stream belongs to
        """
        return self.streaming_context

    def get_parallelism(self):
        """
        Returns:
            the parallelism of this transformation
        """
        return self.parallelism

    def set_parallelism(self, parallelism: int):
        """Sets the parallelism of this transformation

        Args:
            parallelism: The new parallelism to set on this transformation

        Returns:
            self
        """
        self.parallelism = parallelism
        client = self._gateway_client()
        client.call_method(self._j_stream, "setParallelism", parallelism)
        return self

    def get_input_stream(self):
        """
        Returns:
            input stream of this stream
        """
        return self.input_stream

    def get_id(self):
        """
        Returns:
            A unique id that identifies this stream.
        """
        client = self._gateway_client()
        return client.call_method(self._j_stream, "getId")

    def _gateway_client(self):
        """Return the gateway client of the owning streaming context."""
        context = self.get_streaming_context()
        return context._gateway_client
class DataStream(Stream):
    """
    A stream of data whose transformations are executed by python. It also
    wraps the java `org.ray.streaming.python.stream.PythonDataStream`.
    """

    def __init__(self, input_stream, j_stream, streaming_context=None):
        super().__init__(
            input_stream, j_stream, streaming_context=streaming_context)

    def _call_with_py_func(self, j_method, func):
        """Serialize ``func``, register it with the gateway and call
        ``j_method`` on the underlying java stream with it.

        Returns:
            Whatever the gateway returns for the call.
        """
        j_func = self._gateway_client().create_py_func(
            function.serialize(func))
        return self._gateway_client(). \
            call_method(self._j_stream, j_method, j_func)

    def map(self, func):
        """
        Applies a Map transformation on a :class:`DataStream`.
        A :class:`ray.streaming.function.MapFunction` is called for each
        element of the DataStream.

        Args:
            func: The MapFunction called for each element. A plain python
                function is wrapped as SimpleMapFunction.

        Returns:
            A new data stream transformed by the MapFunction.
        """
        if not isinstance(func, function.MapFunction):
            func = function.SimpleMapFunction(func)
        return DataStream(self, self._call_with_py_func("map", func))

    def flat_map(self, func):
        """
        Applies a FlatMap transformation on a :class:`DataStream`.
        A :class:`ray.streaming.function.FlatMapFunction` is called for each
        element and may produce any number of elements, including none.

        Args:
            func: The FlatMapFunction called for each element. A plain
                python function is wrapped as SimpleFlatMapFunction.

        Returns:
            The transformed DataStream
        """
        if not isinstance(func, function.FlatMapFunction):
            func = function.SimpleFlatMapFunction(func)
        return DataStream(self, self._call_with_py_func("flatMap", func))

    def filter(self, func):
        """
        Applies a Filter transformation on a :class:`DataStream`.
        A :class:`ray.streaming.function.FilterFunction` is called for each
        element; only elements for which it returns True are retained.

        Args:
            func: The FilterFunction called for each element. A plain
                python function is wrapped as SimpleFilterFunction.

        Returns:
            The filtered DataStream
        """
        if not isinstance(func, function.FilterFunction):
            func = function.SimpleFilterFunction(func)
        return DataStream(self, self._call_with_py_func("filter", func))

    def key_by(self, func):
        """
        Creates a new :class:`KeyDataStream` that partitions the data stream
        by the key extracted with ``func``.

        Args:
            func: The KeyFunction used to extract the partitioning key. A
                plain python function is wrapped as SimpleKeyFunction.

        Returns:
            A KeyDataStream
        """
        if not isinstance(func, function.KeyFunction):
            func = function.SimpleKeyFunction(func)
        return KeyDataStream(self, self._call_with_py_func("keyBy", func))

    def broadcast(self):
        """
        Sets the partitioning of the :class:`DataStream` so that the output
        elements are broadcast to every parallel instance of the next
        operation.

        Returns:
            The DataStream with broadcast partitioning set.
        """
        self._gateway_client().call_method(self._j_stream, "broadcast")
        return self

    def partition_by(self, partition_func):
        """
        Sets the partitioning of the :class:`DataStream` so that the
        elements of the stream are partitioned by the given partition
        function.

        Args:
            partition_func: partition function. A plain python function is
                wrapped as SimplePartition.

        Returns:
            The DataStream with specified partitioning set.
        """
        if not isinstance(partition_func, partition.Partition):
            partition_func = partition.SimplePartition(partition_func)
        client = self._gateway_client()
        j_partition = client.create_py_func(
            partition.serialize(partition_func))
        self._gateway_client(). \
            call_method(self._j_stream, "partitionBy", j_partition)
        return self

    def sink(self, func):
        """
        Create a StreamSink with the given sink.

        Args:
            func: sink function. A plain python function is wrapped as
                SimpleSinkFunction.

        Returns:
            a StreamSink.
        """
        if not isinstance(func, function.SinkFunction):
            func = function.SimpleSinkFunction(func)
        return StreamSink(self, self._call_with_py_func("sink", func), func)
class KeyDataStream(Stream):
    """The stream produced by a key-by operation.

    Wrapper of java org.ray.streaming.python.stream.PythonKeyDataStream
    """

    def __init__(self, input_stream, j_stream):
        super().__init__(input_stream, j_stream)

    def reduce(self, func):
        """
        Applies a reduce transformation on the data stream grouped by the
        key function.
        The :class:`ray.streaming.function.ReduceFunction` receives input
        values per key; only values sharing a key go to the same reducer.

        Args:
            func: The ReduceFunction called for every element of the input
                values with the same key. A plain python function is
                wrapped as SimpleReduceFunction.

        Returns:
            A transformed DataStream.
        """
        reducer = func
        if not isinstance(reducer, function.ReduceFunction):
            reducer = function.SimpleReduceFunction(reducer)
        client = self._gateway_client()
        j_func = client.create_py_func(function.serialize(reducer))
        j_stream = self._gateway_client(). \
            call_method(self._j_stream, "reduce", j_func)
        return DataStream(self, j_stream)
class StreamSource(DataStream):
    """The source end of a DataStream.

    Wrapper of java org.ray.streaming.python.stream.PythonStreamSource
    """

    def __init__(self, j_stream, streaming_context, source_func):
        # A source has no upstream, so the context is passed explicitly.
        super().__init__(None, j_stream, streaming_context=streaming_context)
        self.source_func = source_func

    @staticmethod
    def build_source(streaming_context, func):
        """Build a StreamSource from a source function.

        Args:
            streaming_context: Stream context
            func: An instance of `SourceFunction`

        Returns:
            A StreamSource
        """
        serialized_func = function.serialize(func)
        client = streaming_context._gateway_client
        j_stream = client.create_py_stream_source(serialized_func)
        return StreamSource(j_stream, streaming_context, func)
class StreamSink(Stream):
    """The sink end of a DataStream.

    Wrapper of java org.ray.streaming.python.stream.PythonStreamSink
    """

    def __init__(self, input_stream, j_stream, func):
        # `func` is accepted but not stored on the instance.
        super().__init__(input_stream, j_stream)
@@ -1,67 +0,0 @@
import argparse
import logging
import time
import ray
from ray.streaming.streaming import Environment
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser()
parser.add_argument("--input-file", required=True, help="the input text file")
# A class used to check attribute-based key selection
class Record:
    """Holds a two-element record and exposes its first element as
    ``word``."""

    def __init__(self, record):
        # Unpacking requires exactly two elements; the second is not kept
        # separately.
        word, _ignored = record
        self.record = record
        self.word = word
# Splits input line into words and outputs objects of type Record
# each one consisting of a key (word) and a tuple (word,1)
def splitter(line):
    """Return one ``Record((word, 1))`` per whitespace-separated word."""
    return [Record((word, 1)) for word in line.split()]
# Receives an object of type Record and returns the actual tuple
def as_tuple(record):
    """Return the ``record`` attribute of the given object."""
    wrapped = record.record
    return wrapped
if __name__ == "__main__":
    # Read the command line
    args = parser.parse_args()
    input_file = str(args.input_file)

    ray.init()
    ray.register_custom_serializer(Record, use_dict=True)

    # A Ray streaming environment with the default configuration
    env = Environment()
    env.set_parallelism(2)  # Each operator will be executed by two actors

    # 'key_by("word")' partitions the stream of records by the 'word'
    # attribute (see Record class above)
    # 'map(as_tuple)' maps a record of type Record into a tuple
    # 'sum(1)' sums the 2nd element of the tuple, i.e. the word count
    # 'inspect(print)' prints the content of the stream to stdout
    stream = (
        env.read_text_file(input_file)
        .round_robin()
        .flat_map(splitter)
        .key_by("word")
        .map(as_tuple)
        .sum(1)
        .inspect(print)
    )

    start = time.time()
    env_handle = env.execute()  # Deploys and executes the dataflow
    ray.get(env_handle)  # Stay alive until execution finishes
    end = time.time()
    logger.info("Elapsed time: {} secs".format(end - start))
    logger.debug("Output stream id: {}".format(stream.id))
-52
View File
@@ -1,52 +0,0 @@
import argparse
import logging
import time
import ray
from ray.streaming.config import Config
from ray.streaming.streaming import Environment, Conf
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser()
parser.add_argument("--input-file", required=True, help="the input text file")
# Test functions
def splitter(line):
    """Split ``line`` on whitespace and return the list of words."""
    words = line.split()
    return words
def filter_fn(word):
    """Return True when ``word`` contains the letter "f", else False."""
    return "f" in word
if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(local_mode=False)

    # A Ray streaming environment using the native channel type
    env = Environment(config=Conf(channel_type=Config.NATIVE_CHANNEL))

    # Stream represents the output of the filter and
    # can be forked into other dataflows.
    # The final 'inspect' prints the contents of the stream to stdout.
    stream = (
        env.read_text_file(args.input_file)
        .shuffle()
        .flat_map(splitter)
        .set_parallelism(2)
        .filter(filter_fn)
        .set_parallelism(2)
        .inspect(lambda x: print("result", x))
    )

    start = time.time()
    env_handle = env.execute()
    ray.get(env_handle)  # Stay alive until execution finishes
    env.wait_finish()
    end = time.time()
    logger.info("Elapsed time: {} secs".format(end - start))
    logger.debug("Output stream id: {}".format(stream.id))
-5
View File
@@ -1,5 +0,0 @@
This is
a test file
to test if example
works
fine
+17 -35
View File
@@ -4,7 +4,8 @@ import time
import ray
import wikipedia
from ray.streaming.streaming import Environment
from ray.streaming import StreamingContext
from ray.streaming.config import Config
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@@ -23,7 +24,6 @@ class Wikipedia:
def __init__(self, title_file):
# Titles in this file will be as queries
self.title_file = title_file
# TODO (john): Handle possible exception here
self.title_reader = iter(list(open(self.title_file, "r").readlines()))
self.done = False
self.article_done = True
@@ -57,21 +57,7 @@ class Wikipedia:
# Splits input line into words and
# outputs records of the form (word,1)
def splitter(line):
records = []
words = line.split()
for w in words:
records.append((w, 1))
return records
# Returns the first attribute of a tuple
def key_selector(tuple):
return tuple[0]
# Returns the second attribute of a tuple
def attribute_selector(tuple):
return tuple[1]
return [(word, 1) for word in line.split()]
if __name__ == "__main__":
@@ -79,27 +65,23 @@ if __name__ == "__main__":
args = parser.parse_args()
titles_file = str(args.titles_file)
ray.init()
ray.init(load_code_from_local=True, include_java=True)
ctx = StreamingContext.Builder() \
.option(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL) \
.build()
# A Ray streaming environment with the default configuration
env = Environment()
env.set_parallelism(2) # Each operator will be executed by two actors
ctx.set_parallelism(1) # Each operator will be executed by two actors
# The following dataflow is a simple streaming wordcount
# with a rolling sum operator.
# It reads articles from wikipedia, splits them in words,
# shuffles words, and counts the occurences of each word.
stream = env.source(Wikipedia(titles_file)) \
.round_robin() \
.flat_map(splitter) \
.key_by(key_selector) \
.sum(attribute_selector) \
.inspect(print) # Prints the contents of the
# stream to stdout
# Reads articles from wikipedia, splits them in words,
# shuffles words, and counts the occurrences of each word.
stream = ctx.source(Wikipedia(titles_file)) \
.flat_map(splitter) \
.key_by(lambda x: x[0]) \
.reduce(lambda old_value, new_value:
(old_value[0], old_value[1] + new_value[1])) \
.sink(print)
start = time.time()
env_handle = env.execute() # Deploys and executes the dataflow
ray.get(env_handle) # Stay alive until execution finishes
env.wait_finish()
ctx.execute("wordcount")
end = time.time()
logger.info("Elapsed time: {} secs".format(end - start))
logger.debug("Output stream id: {}".format(stream.id))
+315
View File
@@ -0,0 +1,315 @@
import importlib
import inspect
import sys
from abc import ABC, abstractmethod
import typing
import cloudpickle
from ray.streaming.runtime import gateway_client
class Function(ABC):
    """Root interface of every user-defined function.

    Both hooks are no-ops here; subclasses override them as needed.
    """

    def open(self, conf: typing.Dict[str, str]):
        """Lifecycle hook that receives a string-to-string config dict."""

    def close(self):
        """Lifecycle hook for releasing resources."""
class SourceContext(ABC):
    """
    Interface through which source functions emit elements, and possibly
    watermarks."""

    @abstractmethod
    def collect(self, element):
        """Emit a single element from the source, with no timestamp
        attached."""
class SourceFunction(Function):
    """Interface that every source function implements."""

    @abstractmethod
    def init(self, parallel, index):
        """Tell the source where it sits among its parallel instances.

        Args:
            parallel: parallelism of source function
            index: task index of this function, ranging from 0 to
                parallel-1.
        """

    @abstractmethod
    def run(self, ctx: SourceContext):
        """Start the source. Implementations emit elements through the
        given :class:`SourceContext`.
        """

    def close(self):
        """No-op by default."""
class MapFunction(Function):
    """
    Base interface for Map functions. A Map function transforms elements one
    at a time and always yields exactly one result element per input
    element.
    """

    def map(self, value):
        """Transform ``value``; the base implementation returns None."""
class FlatMapFunction(Function):
    """
    Base interface for flatMap functions. A FlatMap function turns each
    element into zero, one, or more elements.
    """

    def flat_map(self, value, collector):
        """Transform one input element into zero, one, or more elements.

        The base implementation does nothing.

        Args:
            value: The input value.
            collector: The collector for returning result values.
        """
class FilterFunction(Function):
    """
    A predicate applied to each record individually; it decides whether the
    element is kept or discarded.
    """

    def filter(self, value):
        """Evaluate the predicate for ``value``.

        The base implementation returns None.

        Args:
            value: The value to be filtered.

        Returns:
            True for values that should be retained, false for values to be
            filtered out.
        """
class KeyFunction(Function):
    """
    An extractor that takes an object and returns the deterministic key for
    that object.
    """

    def key_by(self, value):
        """Deterministically extract the key from an object.

        The base implementation returns None.

        Args:
            value: The object to get the key from.

        Returns:
            The extracted key.
        """
class ReduceFunction(Function):
    """
    Base interface for Reduce functions. A Reduce function collapses a group
    of elements into a single value by repeatedly combining two elements
    into one.
    """

    def reduce(self, old_value, new_value):
        """
        Combine two values into one value of the same type. The function is
        applied consecutively to all values of a group until a single value
        remains.

        The base implementation returns None.

        Args:
            old_value: The old value to combine.
            new_value: The new input value to combine.

        Returns:
            The combined value of both values.
        """
class SinkFunction(Function):
    """Interface for user-defined sink functionality."""

    def sink(self, value):
        """Write ``value`` to the sink; called once for every record."""
class CollectionSourceFunction(SourceFunction):
    """Source that emits every element of an in-memory iterable."""

    def __init__(self, values):
        # Iterated once per `run` call.
        self.values = values

    def init(self, parallel, index):
        """No per-task initialization is needed for this source."""

    def run(self, ctx: SourceContext):
        """Hand each stored value to ``ctx`` in iteration order."""
        emit = ctx.collect
        for item in self.values:
            emit(item)
class LocalFileSourceFunction(SourceFunction):
    """Source that emits the lines of a local text file, one per record."""

    def __init__(self, filename):
        # Path of the file opened by `run`.
        self.filename = filename

    def init(self, parallel, index):
        """No per-task initialization is needed for this source."""

    def run(self, ctx: SourceContext):
        """Read the file line by line and emit each line without its newline.

        Args:
            ctx: The source context used to emit lines.
        """
        with open(self.filename, "r") as f:
            for line in f:
                # Strip only a real trailing newline: the last line of a
                # file may lack one, and slicing blindly would drop its
                # final character.
                ctx.collect(line[:-1] if line.endswith("\n") else line)
class SimpleMapFunction(MapFunction):
    """Wrap a python callable as a :class:`MapFunction`."""

    def __init__(self, func):
        self.func = func

    def map(self, value):
        """Return the wrapped callable's result for ``value``."""
        mapped = self.func(value)
        return mapped
class SimpleFlatMapFunction(FlatMapFunction):
    """
    Wrap a python function as :class:`FlatMapFunction`

    >>> assert SimpleFlatMapFunction(lambda x: x.split())
    >>> def flat_func(x, collector):
    ...     for item in x.split():
    ...         collector.collect(item)
    >>> assert SimpleFlatMapFunction(flat_func)
    """

    def __init__(self, func):
        """
        Args:
            func: a python function that either takes one input element and
                returns an iterable of zero, one, or more elements, or
                takes an input element plus a collector and emits its
                results through the collector.
        """
        self.func = func
        self.process_func = None
        arity = len(inspect.signature(func).parameters)
        assert arity <= 2,\
            "func should receive value [, collector] as arguments"

        def forward_to_func(value, collector):
            # Two-argument form: func emits through the collector itself.
            func(value, collector)

        def collect_results(value, collector):
            # One-argument form: func returns the elements to emit.
            for elem in func(value):
                collector.collect(elem)

        self.process_func = (forward_to_func
                             if arity == 2 else collect_results)

    def flat_map(self, value, collector):
        """Run the strategy chosen at construction time."""
        self.process_func(value, collector)
class SimpleFilterFunction(FilterFunction):
    """Wrap a python callable as a :class:`FilterFunction`."""

    def __init__(self, func):
        self.func = func

    def filter(self, value):
        """Return the wrapped predicate's verdict for ``value``."""
        keep = self.func(value)
        return keep
class SimpleKeyFunction(KeyFunction):
    """Wrap a python callable as a :class:`KeyFunction`."""

    def __init__(self, func):
        self.func = func

    def key_by(self, value):
        """Return the key the wrapped callable extracts from ``value``."""
        key = self.func(value)
        return key
class SimpleReduceFunction(ReduceFunction):
    """Wrap a python callable as a :class:`ReduceFunction`."""

    def __init__(self, func):
        self.func = func

    def reduce(self, old_value, new_value):
        """Return the wrapped callable's combination of both values."""
        combined = self.func(old_value, new_value)
        return combined
class SimpleSinkFunction(SinkFunction):
    """Wrap a python callable as a :class:`SinkFunction`."""

    def __init__(self, func):
        self.func = func

    def sink(self, value):
        """Pass ``value`` to the wrapped callable and return its result."""
        outcome = self.func(value)
        return outcome
def serialize(func: Function):
    """Serialize a streaming :class:`Function` with cloudpickle.

    Returns:
        The bytes produced by ``cloudpickle.dumps``.
    """
    payload = cloudpickle.dumps(func)
    return payload
def deserialize(func_bytes):
    """Rebuild a function from bytes produced by the `serialize` method.

    NOTE: this unpickles its input, so `func_bytes` must come from a
    trusted source.
    """
    restored = cloudpickle.loads(func_bytes)
    return restored
def load_function(descriptor_func_bytes: bytes):
    """
    Deserialize `descriptor_func_bytes` to get function info, then
    get or load streaming function.

    Note that this function must be kept in sync with
    `org.ray.streaming.runtime.python.GraphPbBuilder.serializeFunction`

    Args:
        descriptor_func_bytes: serialized function info

    Returns:
        a streaming function
    """
    descriptor = gateway_client.deserialize(descriptor_func_bytes)
    (function_bytes, module_name, class_name, function_name,
     function_interface) = descriptor

    # A pickled function takes precedence over a by-name reference.
    if function_bytes:
        return deserialize(function_bytes)

    assert module_name
    assert function_interface
    # The interface is looked up by name among this module's attributes.
    function_interface = getattr(sys.modules[__name__], function_interface)
    mod = importlib.import_module(module_name)

    if class_name:
        # A user class implementing the interface: instantiate it.
        assert function_name is None
        cls = getattr(mod, class_name)
        assert issubclass(cls, function_interface)
        return cls()

    # A plain python function: wrap it in the matching Simple* class.
    assert function_name
    wrapper_class = _get_simple_function_class(function_interface)
    return wrapper_class(getattr(mod, function_name))
def _get_simple_function_class(function_interface):
    """Get the wrapper class for the given `function_interface`.

    Scans this module for a class that subclasses `function_interface`,
    is not the interface itself, and whose name starts with "Simple".

    Args:
        function_interface: The function interface class to find a
            wrapper for.

    Returns:
        The first matching class in `inspect.getmembers` order.

    Raises:
        Exception: If this module defines no matching wrapper class.
    """
    for _, obj in inspect.getmembers(sys.modules[__name__]):
        if (inspect.isclass(obj) and issubclass(obj, function_interface)
                and obj is not function_interface
                and obj.__name__.startswith("Simple")):
            return obj
    # str.format needs a `{}` placeholder; the previous `%s` was emitted
    # verbatim, so the message never named the interface.
    raise Exception(
        "SimpleFunction for {} doesn't exist".format(function_interface))
+3 -3
View File
@@ -155,7 +155,7 @@ cdef class DataWriter:
ctx.get().MarkMockTest()
if config_bytes:
config_data = config_bytes
channel_logger.info("load config, config bytes size: %s", config_data.nbytes)
channel_logger.info("DataWriter load config, config bytes size: %s", config_data.nbytes)
ctx.get().SetConfig(<uint8_t *>(&config_data[0]), config_data.nbytes)
c_writer = new CDataWriter(ctx)
cdef:
@@ -235,7 +235,7 @@ cdef class DataReader:
cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]()
if config_bytes:
config_data = config_bytes
channel_logger.info("load config, config bytes size: %s", config_data.nbytes)
channel_logger.info("DataReader load config, config bytes size: %s", config_data.nbytes)
ctx.get().SetConfig(<uint8_t *>(&(config_data[0])), config_data.nbytes)
if is_mock:
ctx.get().MarkMockTest()
@@ -289,7 +289,7 @@ cdef class DataReader:
msg_id = msg.get().GetMessageSeqId()
msgs.append((msg_bytes, msg_id, timestamp, qid_bytes))
return msgs
elif bundle_type == <uint32_t> libstreaming.BundleTypeEmpty:
elif bundle_type == <uint32_t> libstreaming.BundleTypeEmpty:
return []
else:
raise Exception("Unsupported bundle type {}".format(bundle_type))
-120
View File
@@ -1,120 +0,0 @@
import logging
import pickle
import threading
import ray
import ray.streaming._streaming as _streaming
from ray.streaming.config import Config
from ray._raylet import PythonFunctionDescriptor
from ray.streaming.communication import DataInput, DataOutput
logger = logging.getLogger(__name__)
@ray.remote
class JobWorker:
    """A streaming job worker.

    Attributes:
        worker_id: The id of the instance.
        operator: The operator whose ``processor_class`` is instantiated
            and run by this worker.
        input_channels: The input channels; when non-empty a DataInput
            gate is created for them in `init`.
        output_channels: The output channels; when non-empty a DataOutput
            gate is created for them in `init`.
    """

    def __init__(self, worker_id, operator, input_channels, output_channels):
        self.worker_id = worker_id
        self.operator = operator
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.processor_name = operator.processor_class.__name__
        self.processor_instance = operator.processor_class(operator)
        # Everything below is populated by `init`.
        self.env = None
        self.input_gate = None
        self.output_gate = None
        self.reader_client = None
        self.writer_client = None

    def init(self, env):
        """init streaming actor"""
        self.env = env = pickle.loads(env)
        logger.info("init operator instance %s", self.processor_name)

        if env.config.channel_type == Config.NATIVE_CHANNEL:
            core_worker = ray.worker.global_worker.core_worker
            owner_name = self.__class__.__name__

            def describe(handler):
                # Descriptor of one of this actor's message handlers.
                return PythonFunctionDescriptor(__name__, handler.__name__,
                                                owner_name)

            self.reader_client = _streaming.ReaderClient(
                core_worker, describe(self.on_reader_message),
                describe(self.on_reader_message_sync))
            self.writer_client = _streaming.WriterClient(
                core_worker, describe(self.on_writer_message),
                describe(self.on_writer_message_sync))

        if len(self.input_channels) > 0:
            self.input_gate = DataInput(env, self.input_channels)
            self.input_gate.init()
        if len(self.output_channels) > 0:
            self.output_gate = DataOutput(
                env, self.output_channels,
                self.operator.partitioning_strategies)
            self.output_gate.init()
        logger.info("init operator instance %s succeed", self.processor_name)
        return True

    def start(self):
        """Start the actor by running the processor in a daemon thread."""
        self.t = threading.Thread(target=self.run, daemon=True)
        self.t.start()
        logger.info("%s %s started, actor id %s", self.__class__.__name__,
                    self.processor_name, ray.worker.global_worker.actor_id)

    def run(self):
        """Run the processor to completion, then close both gates."""
        logger.info("%s start running", self.processor_name)
        self.processor_instance.run(self.input_gate, self.output_gate)
        logger.info("%s finished running", self.processor_name)
        self.close()

    def close(self):
        """Close whichever gates were created."""
        for gate in (self.input_gate, self.output_gate):
            if gate:
                gate.close()

    def is_finished(self):
        """Return True once the thread created by `start` has stopped."""
        return not self.t.is_alive()

    def on_reader_message(self, buffer: bytes):
        """used in direct call mode"""
        self.reader_client.on_reader_message(buffer)

    def on_reader_message_sync(self, buffer: bytes):
        """used in direct call mode"""
        if self.reader_client is None:
            return b" " * 4  # special flag to indicate this actor not ready
        return self.reader_client.on_reader_message_sync(buffer).to_pybytes()

    def on_writer_message(self, buffer: bytes):
        """used in direct call mode"""
        self.writer_client.on_writer_message(buffer)

    def on_writer_message_sync(self, buffer: bytes):
        """used in direct call mode"""
        if self.writer_client is None:
            return b" " * 4  # special flag to indicate this actor not ready
        return self.writer_client.on_writer_message_sync(buffer).to_pybytes()
+17
View File
@@ -0,0 +1,17 @@
class Record:
    """Data record in data stream"""

    def __init__(self, value):
        """
        Args:
            value: The payload carried by this record.
        """
        self.value = value
        # Not assigned in this class; starts as None.
        self.stream = None

    def __repr__(self):
        # str.format needs a `{}` placeholder; the previous `%s` was
        # returned verbatim, so the value never appeared in the repr.
        return "Record({})".format(self.value)
class KeyRecord(Record):
    """Data record in a keyed data stream"""

    def __init__(self, key, value):
        """
        Args:
            key: The key associated with this record.
            value: The payload, stored by the base class initializer.
        """
        self.key = key
        super().__init__(value)
+229 -95
View File
@@ -1,109 +1,243 @@
from abc import ABC, abstractmethod
import enum
import logging
import cloudpickle
logger = logging.getLogger(__name__)
logger.setLevel("DEBUG")
from ray import streaming
from ray.streaming import function
from ray.streaming import message
class PScheme:
    """Stream partitioning scheme: a strategy plus an optional function."""

    def __init__(self, strategy, partition_fn=None):
        self.partition_fn = partition_fn
        self.strategy = strategy

    def __repr__(self):
        parts = (self.strategy, self.partition_fn)
        return "({},{})".format(*parts)
class OperatorType(enum.Enum):
    """Kind of operator, distinguished by its number of input streams."""
    # Sources are where the program reads its input from.
    SOURCE = 0
    # The operator has one data stream as its input.
    ONE_INPUT = 1
    # The operator has two data streams as its input.
    TWO_INPUT = 2
class PStrategy(enum.Enum):
    """Partitioning strategies."""
    Forward = 0  # Default
    Shuffle = 1
    Rescale = 2
    RoundRobin = 3
    Broadcast = 4
    Custom = 5
    ShuffleByKey = 6
    # ...
class Operator(ABC):
    """
    Abstract base class for all operators.
    An operator is used to run a :class:`function.Function`.
    """

    @abstractmethod
    def open(self, collectors, runtime_context):
        """Receive the collectors and runtime context before processing."""

    @abstractmethod
    def finish(self):
        """Lifecycle hook; no behavior is defined at this level."""

    @abstractmethod
    def close(self):
        """Lifecycle hook; no behavior is defined at this level."""

    @abstractmethod
    def operator_type(self) -> OperatorType:
        """Return the :class:`OperatorType` of this operator."""
class OpType(enum.Enum):
    """Operator types."""
    Source = 0
    Map = 1
    FlatMap = 2
    Filter = 3
    TimeWindow = 4
    KeyBy = 5
    Sink = 6
    WindowJoin = 7
    Inspect = 8
    ReadTextFile = 9
    Reduce = 10
    Sum = 11
    # ...
class OneInputOperator(Operator, ABC):
    """Interface for stream operators with one input."""

    @abstractmethod
    def process_element(self, record):
        """Handle one record from the single input stream."""

    def operator_type(self):
        """Report this operator as a one-input operator."""
        kind = OperatorType.ONE_INPUT
        return kind
class Operator:
    """A logical dataflow operator."""

    def __init__(self,
                 id,
                 op_type,
                 processor_class,
                 name="",
                 logic=None,
                 num_instances=1,
                 other=None,
                 state_actor=None):
        # Identity and kind of the operator.
        self.id = id
        self.type = op_type
        self.name = name
        self.processor_class = processor_class
        # The operator's logic, kept in cloudpickle-serialized form.
        serialized_logic = cloudpickle.dumps(logic)
        self._logic = serialized_logic
        self.num_instances = num_instances
        # One partitioning strategy per downstream operator (default: forward)
        self.partitioning_strategies = {}
        # Depends on the type of the operator.
        self.other_args = other
        # Actor to query state.
        self.state_actor = state_actor
class TwoInputOperator(Operator, ABC):
    """Interface for stream operators with two inputs."""
def _set_partition_strategy(self,
                            stream_id,
                            partitioning_scheme,
                            dest_operator=None):
    """Set the partitioning scheme for an output stream of the operator.

    The scheme is stored together with `dest_operator` as a tuple keyed
    by `stream_id`.
    """
    entry = (partitioning_scheme, dest_operator)
    self.partitioning_strategies[stream_id] = entry
@abstractmethod
def process_element(self, record1, record2):
    """Handle records from the two inputs; abstract, returns ``None``."""
def _get_partition_strategy(self, stream_id):
    """Retrieve the partitioning entry for an output stream.

    Returns None if no strategy has been defined for `stream_id`.
    """
    strategies = self.partitioning_strategies
    return strategies.get(stream_id, None)
def operator_type(self):
    """Report this operator as a two-input operator."""
    kind = OperatorType.TWO_INPUT
    return kind
def _clean(self):
    """Drop partitioning entries that lack a destination operator.

    Valid entries are re-organized as
    'destination operator id -> partitioning scheme'. When several
    entries share a destination, the first one encountered is kept.
    Should be called only after the logical dataflow has been constructed.
    """
    by_destination = {}
    for strategy, destination in self.partitioning_strategies.values():
        if destination is None:
            continue
        if destination not in by_destination:
            by_destination[destination] = strategy
    self.partitioning_strategies = by_destination
def print(self):
    """Log a multi-line description of this operator at DEBUG level."""
    template = ("Operator<\nID = {}\nName = {}\nprocessor_class = {}\n"
                "Logic = {}\nNumber_of_Instances = {}\n"
                "Partitioning_Scheme = {}\nOther_Args = {}>\n")
    description = template.format(
        self.id, self.name, self.processor_class, self.logic,
        self.num_instances, self.partitioning_strategies, self.other_args)
    logger.debug(description)
class StreamOperator(Operator, ABC):
    """
    Basic interface for stream operators. Implementers would implement one of
    :class:`OneInputOperator` or :class:`TwoInputOperator` to create
    operators that process elements.
    """
@property
def logic(self):
    """Return the logic, deserialized with cloudpickle on every access."""
    restored = cloudpickle.loads(self._logic)
    return restored
def __init__(self, func):
    """Store the function to run; collectors and context arrive in open."""
    self.func = func
    self.runtime_context = None
    self.collectors = None


def open(self, collectors, runtime_context):
    """Remember the downstream collectors and the runtime context."""
    self.runtime_context = runtime_context
    self.collectors = collectors


def finish(self):
    """Nothing to do by default."""


def close(self):
    """Nothing to do by default."""


def collect(self, record):
    """Pass `record` to every collector, in order."""
    for downstream in self.collectors:
        downstream.collect(record)
class SourceOperator(StreamOperator):
    """
    Operator to run a :class:`function.SourceFunction`
    """

    class SourceContextImpl(function.SourceContext):
        """Source context that wraps values in records for the collectors."""

        def __init__(self, collectors):
            self.collectors = collectors

        def collect(self, value):
            # A fresh Record is built for each collector.
            for downstream in self.collectors:
                downstream.collect(message.Record(value))

    def __init__(self, func):
        assert isinstance(func, function.SourceFunction)
        super().__init__(func)
        self.source_context = None

    def open(self, collectors, runtime_context):
        """Create the source context and initialize the source function."""
        super().open(collectors, runtime_context)
        self.source_context = SourceOperator.SourceContextImpl(collectors)
        parallelism = runtime_context.get_parallelism()
        task_index = runtime_context.get_task_index()
        self.func.init(parallelism, task_index)

    def run(self):
        """Run the source function against the source context."""
        self.func.run(self.source_context)

    def operator_type(self):
        return OperatorType.SOURCE
class MapOperator(StreamOperator, OneInputOperator):
    """
    Operator to run a :class:`function.MapFunction`
    """

    def __init__(self, map_func: function.MapFunction):
        assert isinstance(map_func, function.MapFunction)
        super().__init__(map_func)

    def process_element(self, record):
        """Map the record's value and emit the result as a new Record."""
        mapped = self.func.map(record.value)
        self.collect(message.Record(mapped))
class FlatMapOperator(StreamOperator, OneInputOperator):
    """
    Operator to run a :class:`function.FlatMapFunction`
    """

    def __init__(self, flat_map_func: function.FlatMapFunction):
        assert isinstance(flat_map_func, function.FlatMapFunction)
        super().__init__(flat_map_func)
        # Created in `open`, once the collectors are known.
        self.collection_collector = None

    def open(self, collectors, runtime_context):
        """Wrap the collectors in a CollectionCollector for flat_map."""
        super().open(collectors, runtime_context)
        collector_class = streaming.collector.CollectionCollector
        self.collection_collector = collector_class(collectors)

    def process_element(self, record):
        """Let the function emit its results through the collector."""
        self.func.flat_map(record.value, self.collection_collector)
class FilterOperator(StreamOperator, OneInputOperator):
    """
    Operator to run a :class:`function.FilterFunction`
    """

    def __init__(self, filter_func: function.FilterFunction):
        assert isinstance(filter_func, function.FilterFunction)
        super().__init__(filter_func)

    def process_element(self, record):
        """Forward the record unchanged when the predicate accepts it."""
        if not self.func.filter(record.value):
            return
        self.collect(record)
class KeyByOperator(StreamOperator, OneInputOperator):
    """
    Operator to run a :class:`function.KeyFunction`
    """

    def __init__(self, key_func: function.KeyFunction):
        assert isinstance(key_func, function.KeyFunction)
        super().__init__(key_func)

    def process_element(self, record):
        """Emit a KeyRecord pairing the extracted key with the value."""
        value = record.value
        keyed = message.KeyRecord(self.func.key_by(value), record.value)
        self.collect(keyed)
class ReduceOperator(StreamOperator, OneInputOperator):
    """
    Operator to run a :class:`function.ReduceFunction`
    """

    def __init__(self, reduce_func: function.ReduceFunction):
        assert isinstance(reduce_func, function.ReduceFunction)
        super().__init__(reduce_func)
        # key -> last reduced value for that key
        self.reduce_state = {}

    def open(self, collectors, runtime_context):
        super().open(collectors, runtime_context)

    def process_element(self, record: message.KeyRecord):
        """Fold the record's value into the running result for its key.

        The first record of a key is stored and forwarded as-is; later
        records emit a plain Record holding the newly reduced value.
        """
        key, value = record.key, record.value
        try:
            previous = self.reduce_state[key]
        except KeyError:
            self.reduce_state[key] = value
            self.collect(record)
        else:
            reduced = self.func.reduce(previous, value)
            self.reduce_state[key] = reduced
            self.collect(message.Record(reduced))
class SinkOperator(StreamOperator, OneInputOperator):
    """
    Operator to run a :class:`function.SinkFunction`
    """

    def __init__(self, sink_func: function.SinkFunction):
        assert isinstance(sink_func, function.SinkFunction)
        super().__init__(sink_func)

    def process_element(self, record):
        """Hand the record's value to the sink function; emits nothing."""
        value = record.value
        self.func.sink(value)
# Maps each function interface to the operator class that runs it.
# `create_operator` looks classes up here while walking a function's MRO.
_function_to_operator = {
    function.SourceFunction: SourceOperator,
    function.MapFunction: MapOperator,
    function.FlatMapFunction: FlatMapOperator,
    function.FilterFunction: FilterOperator,
    function.KeyFunction: KeyByOperator,
    function.ReduceFunction: ReduceOperator,
    function.SinkFunction: SinkOperator,
}
def create_operator(func: function.Function):
    """Create an operator according to a :class:`function.Function`

    The function's MRO is walked in order and the first class registered
    in `_function_to_operator` decides the operator class.

    Args:
        func: a subclass of function.Function

    Returns:
        an operator
    """
    lookups = (_function_to_operator.get(super_class, None)
               for super_class in func.__class__.mro())
    operator_class = next(
        (found for found in lookups if found is not None), None)
    assert operator_class is not None
    return operator_class(func)
+117
View File
@@ -0,0 +1,117 @@
import importlib
from abc import ABC, abstractmethod
import cloudpickle
from ray.streaming.runtime import gateway_client
class Partition(ABC):
    """Interface of the partitioning strategy."""

    @abstractmethod
    def partition(self, record, num_partition: int):
        """Decide which downstream partition(s) should receive a record.

        Args:
            record: The record.
            num_partition: num of partitions

        Returns:
            IDs of the downstream partitions that should receive the record.
        """
class BroadcastPartition(Partition):
    """Broadcast the record to all downstream partitions."""

    def __init__(self):
        # Cached list of all partition ids, rebuilt when the count changes.
        self.__partitions = []

    def partition(self, record, num_partition: int):
        """Return the ids ``0 .. num_partition - 1``."""
        if len(self.__partitions) == num_partition:
            return self.__partitions
        self.__partitions = list(range(num_partition))
        return self.__partitions
class KeyPartition(Partition):
    """Partition the record by the key."""

    def __init__(self):
        # Single-element result list, reused across calls.
        self.__partitions = [-1]

    def partition(self, key_record, num_partition: int):
        """Return ``[abs(hash(key)) % num_partition]`` for the record's key."""
        # TODO support key group
        key_hash = abs(hash(key_record.key))
        self.__partitions[0] = key_hash % num_partition
        return self.__partitions
class RoundRobinPartition(Partition):
    """Partition record to downstream tasks in a round-robin manner."""

    def __init__(self):
        # Single-element result list, reused across calls.
        self.__partitions = [-1]
        # Partition id handed out by the most recent call.
        self.seq = 0

    def partition(self, key_record, num_partition: int):
        """Advance the cursor by one (wrapping) and return it in a list."""
        next_seq = (self.seq + 1) % num_partition
        self.seq = next_seq
        self.__partitions[0] = next_seq
        return self.__partitions
class SimplePartition(Partition):
    """Wrap a python function as subclass of :class:`Partition`"""

    def __init__(self, func):
        self.func = func

    def partition(self, record, num_partition: int):
        """Return whatever the wrapped function returns for these inputs."""
        chosen = self.func(record, num_partition)
        return chosen
def serialize(partition_func):
    """
    Serialize the partition function with cloudpickle so that it can be
    deserialized by :func:`deserialize`
    """
    payload = cloudpickle.dumps(partition_func)
    return payload
def deserialize(partition_bytes):
    """Deserialize the binary partition function serialized by
    :func:`serialize`

    NOTE: this unpickles its input, so `partition_bytes` must come from a
    trusted source.
    """
    restored = cloudpickle.loads(partition_bytes)
    return restored
def load_partition(descriptor_partition_bytes: bytes):
    """
    Deserialize `descriptor_partition_bytes` to get partition info, then
    get or load partition function.

    Note that this function must be kept in sync with
    `org.ray.streaming.runtime.python.GraphPbBuilder.serializePartition`

    Args:
        descriptor_partition_bytes: serialized partition info

    Returns:
        partition function
    """
    descriptor = gateway_client.deserialize(descriptor_partition_bytes)
    partition_bytes, module_name, class_name, function_name = descriptor

    # A pickled partition takes precedence over a by-name reference.
    if partition_bytes:
        return deserialize(partition_bytes)

    assert module_name
    mod = importlib.import_module(module_name)

    # If class_name is set, the user partition is a sub class of Partition.
    if class_name:
        assert function_name is None
        partition_class = getattr(mod, class_name)
        return partition_class()

    # Otherwise the user partition is a simple python function, which is
    # wrapped as a SimplePartition.
    assert function_name
    return SimplePartition(getattr(mod, function_name))
-222
View File
@@ -1,222 +0,0 @@
import logging
import sys
import time
import types
logger = logging.getLogger(__name__)
logger.setLevel("INFO")
def _identity(element):
    """Return `element` unchanged."""
    unchanged = element
    return unchanged
class ReadTextFile:
    """A source operator instance that reads a text file line by line.

    Attributes:
        filepath (string): The path to the input file.
        reader: The file object, opened for reading at construction time.
    """

    def __init__(self, operator):
        self.filepath = operator.other_args
        # TODO (john): Handle possible exception here
        self.reader = open(self.filepath, "r")

    def run(self, input_gate, output_gate):
        """Push every line of the file, without its newline, downstream.

        Args:
            input_gate: Unused; a source has no input.
            output_gate: Gate whose `push` receives each line.
        """
        try:
            for line in self.reader:
                # Strip only a real trailing newline: the last line of a
                # file may lack one, and slicing blindly would drop its
                # final character.
                output_gate.push(line[:-1] if line.endswith("\n") else line)
        finally:
            # Close on EOF and also when a push raises.
            self.reader.close()
class Map:
    """A map operator instance that applies a user-defined
    stream transformation.

    A map produces exactly one output record for each record in
    the input stream.
    """

    def __init__(self, operator):
        self.map_fn = operator.logic

    def run(self, input_gate, output_gate):
        """Map each input record and push the result downstream.

        Stops when `input_gate.pull()` returns None.
        """
        record = input_gate.pull()
        while record is not None:
            output_gate.push(self.map_fn(record))
            record = input_gate.pull()
class FlatMap:
    """A flatmap operator instance that applies a user-defined
    stream transformation.

    Whatever the function returns for a record is handed to the output
    gate's `push_all`.

    Attributes:
        flatmap_fn (function): The user-defined function.
    """

    def __init__(self, operator):
        self.flatmap_fn = operator.logic

    def run(self, input_gate, output_gate):
        """Apply the function to each record and push all of its results.

        Stops when `input_gate.pull()` returns None.
        """
        record = input_gate.pull()
        while record is not None:
            produced = self.flatmap_fn(record)
            output_gate.push_all(produced)
            record = input_gate.pull()
class Filter:
    """A filter operator instance that applies a user-defined filter to
    each record of the stream.

    Output records are those that pass the filter, i.e. those for which
    the filter function returns a true value.

    Attributes:
        filter_fn (function): The user-defined boolean function.
    """

    def __init__(self, operator):
        self.filter_fn = operator.logic

    def run(self, input_gate, output_gate):
        """Push every record accepted by the filter downstream.

        Stops when `input_gate.pull()` returns None.
        """
        record = input_gate.pull()
        while record is not None:
            accepted = self.filter_fn(record)
            if accepted:
                output_gate.push(record)
            record = input_gate.pull()
class Inspect:
    """An inspect operator instance that inspects the content of the stream.

    Inspect is useful for printing the records in the stream.
    """

    def __init__(self, operator):
        self.inspect_fn = operator.logic

    def run(self, input_gate, output_gate):
        """Apply the inspect logic (e.g. print) to every input record.

        The stream is left unaffected: each record is first pushed to the
        output gate, if there is one, and then inspected. Stops when
        `input_gate.pull()` returns None.
        """
        record = input_gate.pull()
        while record is not None:
            if output_gate:
                output_gate.push(record)
            self.inspect_fn(record)
            record = input_gate.pull()
class Reduce:
    """A reduce operator instance that combines a new value for a key
    with the last reduced one according to a user-defined logic.

    Attributes:
        reduce_fn (function): The user-defined reduce function.
        attribute_selector (function): Extracts the value to reduce from
            the non-key part of a record.
        state (dict): key -> last reduced value.
    """

    def __init__(self, operator):
        self.reduce_fn = operator.logic
        # `operator.other_args` may be None, an index, an attribute name,
        # or a function. The raw selector is captured in a local so the
        # closures below never read `self.attribute_selector`, which by
        # call time holds the closure itself.
        raw_selector = operator.other_args
        if raw_selector is None:
            attribute_selector = _identity
        elif isinstance(raw_selector, int):
            self.key_index = raw_selector

            def attribute_selector(record):
                return record[raw_selector]
        elif isinstance(raw_selector, str):

            def attribute_selector(record):
                return vars(record)[raw_selector]
        elif isinstance(raw_selector, types.FunctionType):
            attribute_selector = raw_selector
        else:
            sys.exit("Unrecognized or unsupported key selector.")
        self.attribute_selector = attribute_selector
        self.state = {}  # key -> value

    def run(self, input_gate, output_gate):
        """Reduce each (key, rest) record into the state of its key.

        Combines the selected input value with the last reduced value for
        that key and outputs the result as (key, new value). Stops when
        `input_gate.pull()` returns None.
        """
        record = input_gate.pull()
        while record is not None:
            key, rest = record
            new_value = self.attribute_selector(rest)
            # An explicit membership test keeps a KeyError raised inside
            # reduce_fn from being mistaken for a missing key.
            if key in self.state:
                new_value = self.reduce_fn(self.state[key], new_value)
            self.state[key] = new_value
            output_gate.push((key, new_value))
            record = input_gate.pull()

    def get_state(self):
        """Return the state of the actor."""
        return self.state
class KeyBy:
    """A key_by operator instance that physically partitions the
    stream based on a key.

    Attributes:
        key_selector (function): Extracts the key from a record.
    """

    def __init__(self, operator):
        # `operator.other_args` may be an index, an attribute name, or a
        # function. The raw selector is captured in a local so the
        # closures below never read `self.key_selector`, which by call
        # time holds the closure itself.
        raw_selector = operator.other_args
        if isinstance(raw_selector, int):

            def key_selector(record):
                return record[raw_selector]
        elif isinstance(raw_selector, str):

            def key_selector(record):
                return vars(record)[raw_selector]
        elif isinstance(raw_selector, types.FunctionType):
            key_selector = raw_selector
        else:
            sys.exit("Unrecognized or unsupported key selector.")
        self.key_selector = key_selector

    def run(self, input_gate, output_gate):
        """Push each record downstream as a (key, record) pair.

        The actual partitioning is done by the output gate. Stops when
        `input_gate.pull()` returns None.
        """
        record = input_gate.pull()
        while record is not None:
            output_gate.push((self.key_selector(record), record))
            record = input_gate.pull()
class Source:
    """A custom source actor."""

    def __init__(self, operator):
        # The user-defined source with a get_next() method
        self.source = operator.logic

    def run(self, input_gate, output_gate):
        """Start the source by calling get_next() repeatedly.

        Every truthy record is pushed downstream; the first falsy record
        ends the run, after the push rate is logged at DEBUG level.
        """
        started_at = time.time()
        pushed = 0
        record = self.source.get_next()
        while record:
            output_gate.push(record)
            pushed += 1
            record = self.source.get_next()
        logger.debug("[writer] puts per second: {}".format(
            pushed / (time.time() - started_at)))
@@ -0,0 +1,67 @@
# -*- coding: UTF-8 -*-
"""Module to interact between java and python
"""
import msgpack
import ray
class GatewayClient:
    """GatewayClient is used to interact with `PythonGateway` java actor"""

    _PYTHON_GATEWAY_CLASSNAME = \
        b"org.ray.streaming.runtime.python.PythonGateway"

    def __init__(self):
        gateway_class = ray.java_actor_class(
            GatewayClient._PYTHON_GATEWAY_CLASSNAME)
        self._python_gateway_actor = gateway_class.remote()

    def create_streaming_context(self):
        """Call `createStreamingContext` and return the decoded reply."""
        reply = self._python_gateway_actor.createStreamingContext.remote()
        return deserialize(ray.get(reply))

    def with_config(self, conf):
        """Send the serialized `conf` and wait for the call to finish."""
        payload = serialize(conf)
        ray.get(self._python_gateway_actor.withConfig.remote(payload))

    def execute(self, job_name):
        """Call `execute` with the job name and wait for it to finish."""
        payload = serialize(job_name)
        ray.get(self._python_gateway_actor.execute.remote(payload))

    def create_py_stream_source(self, serialized_func):
        """Call `createPythonStreamSource` and return the decoded reply."""
        assert isinstance(serialized_func, bytes)
        reply = self._python_gateway_actor.createPythonStreamSource\
            .remote(serialized_func)
        return deserialize(ray.get(reply))

    def create_py_func(self, serialized_func):
        """Call `createPyFunc` and return the decoded reply."""
        assert isinstance(serialized_func, bytes)
        reply = self._python_gateway_actor.createPyFunc.remote(
            serialized_func)
        return deserialize(ray.get(reply))

    def create_py_partition(self, serialized_partition):
        """Call `createPyPartition` and return the decoded reply."""
        assert isinstance(serialized_partition, bytes)
        reply = self._python_gateway_actor.createPyPartition\
            .remote(serialized_partition)
        return deserialize(ray.get(reply))

    def call_function(self, java_class, java_function, *args):
        """Call `callFunction` with the class, function name and args."""
        java_params = serialize([java_class, java_function] + list(args))
        reply = self._python_gateway_actor.callFunction.remote(java_params)
        return deserialize(ray.get(reply))

    def call_method(self, java_object, java_method, *args):
        """Call `callMethod` with the object, method name and args."""
        java_params = serialize([java_object, java_method] + list(args))
        reply = self._python_gateway_actor.callMethod.remote(java_params)
        return deserialize(ray.get(reply))
def serialize(obj) -> bytes:
    """Serialize a python object which can be deserialized by `PythonGateway`

    The object is packed with msgpack using `use_bin_type=True`.
    """
    packed = msgpack.packb(obj, use_bin_type=True)
    return packed
def deserialize(data: bytes):
    """Deserialize the binary data serialized by `PythonGateway`

    The data is unpacked with msgpack using `raw=False`.
    """
    unpacked = msgpack.unpackb(data, raw=False)
    return unpacked
+102
View File
@@ -0,0 +1,102 @@
import enum
import ray
import ray.streaming.generated.remote_call_pb2 as remote_call_pb
import ray.streaming.generated.streaming_pb2 as streaming_pb
import ray.streaming.operator as operator
import ray.streaming.partition as partition
from ray.streaming import function
from ray.streaming.generated.streaming_pb2 import Language
class NodeType(enum.Enum):
    """Role of a node in the execution graph.

    SOURCE: Sources are where your program reads its input from.
    TRANSFORM: Operators transform one or more DataStreams into a new
        DataStream. Programs can combine multiple transformations into
        sophisticated dataflow topologies.
    SINK: Sinks consume DataStreams and forward them to files, sockets,
        external systems, or print them.
    """
    SOURCE = 0
    TRANSFORM = 1
    SINK = 2
class ExecutionNode:
    """Python-side view of one node of the execution graph.

    Attributes:
        node_id: Id copied from the protobuf node.
        node_type: The :class:`NodeType` member named like the protobuf
            node type.
        parallelism: Parallelism copied from the protobuf node.
        stream_operator: Operator created from the node's function
            descriptor for Python nodes; None for any other language.
        execution_tasks: One :class:`ExecutionTask` per protobuf task.
        input_edges: One :class:`ExecutionEdge` per protobuf input edge.
        output_edges: One :class:`ExecutionEdge` per protobuf output edge.
    """

    def __init__(self, node_pb):
        self.node_id = node_pb.node_id
        self.node_type = NodeType[streaming_pb.NodeType.Name(
            node_pb.node_type)]
        self.parallelism = node_pb.parallelism
        # Always define the attribute so non-Python nodes expose None
        # instead of raising AttributeError on access.
        self.stream_operator = None
        if node_pb.language == Language.PYTHON:
            func_bytes = node_pb.function  # python function descriptor
            func = function.load_function(func_bytes)
            self.stream_operator = operator.create_operator(func)
        self.execution_tasks = [
            ExecutionTask(task) for task in node_pb.execution_tasks
        ]
        self.input_edges = [
            ExecutionEdge(edge, node_pb.language)
            for edge in node_pb.input_edges
        ]
        self.output_edges = [
            ExecutionEdge(edge, node_pb.language)
            for edge in node_pb.output_edges
        ]
class ExecutionEdge:
    """Python-side view of one edge of the execution graph.

    Attributes:
        src_node_id: Source node id copied from the protobuf edge.
        target_node_id: Target node id copied from the protobuf edge.
        partition: Partition loaded from the edge's descriptor when
            `language` is Python; None for any other language.
    """

    def __init__(self, edge_pb, language):
        self.src_node_id = edge_pb.src_node_id
        self.target_node_id = edge_pb.target_node_id
        partition_bytes = edge_pb.partition
        # Always define the attribute so non-Python edges expose None
        # instead of raising AttributeError on access.
        self.partition = None
        if language == Language.PYTHON:
            self.partition = partition.load_partition(partition_bytes)
class ExecutionTask:
    """Python-side view of one task of an execution node.

    Attributes:
        task_id: Id copied from the protobuf task.
        task_index: Index copied from the protobuf task.
        worker_actor: Actor handle rebuilt from the protobuf task's
            serialized `worker_actor`.
    """

    def __init__(self, task_pb):
        self.task_id = task_pb.task_id
        self.task_index = task_pb.task_index
        rebuild_handle = ray.actor.ActorHandle._deserialization_helper
        self.worker_actor = rebuild_handle(task_pb.worker_actor, False)
class ExecutionGraph:
    """Python-side view of the execution graph protobuf message."""

    def __init__(self, graph_pb: "remote_call_pb.ExecutionGraph"):
        """
        Args:
            graph_pb: The protobuf execution graph to wrap.
        """
        self._graph_pb = graph_pb
        # Stored under a private name: assigning to `self.execution_nodes`
        # used to shadow the accessor of the same name.
        self._execution_nodes = [
            ExecutionNode(node) for node in graph_pb.execution_nodes
        ]

    def build_time(self):
        """Return the `build_time` field of the wrapped protobuf graph."""
        return self._graph_pb.build_time

    @property
    def execution_nodes(self):
        """The list of :class:`ExecutionNode` objects of this graph.

        Exposed as a read-only property so attribute access
        (`graph.execution_nodes`) keeps returning the list.
        """
        return self._execution_nodes

    def get_execution_task_by_task_id(self, task_id):
        """Return the task with the given id.

        Raises:
            Exception: If no node owns a task with `task_id`.
        """
        for execution_node in self._execution_nodes:
            for task in execution_node.execution_tasks:
                if task.task_id == task_id:
                    return task
        raise Exception("Task {} does not exist!".format(task_id))

    def get_execution_node_by_task_id(self, task_id):
        """Return the node that owns the task with the given id.

        Raises:
            Exception: If no node owns a task with `task_id`.
        """
        for execution_node in self._execution_nodes:
            for task in execution_node.execution_tasks:
                if task.task_id == task_id:
                    return execution_node
        raise Exception("Task {} does not exist!".format(task_id))

    def get_task_id2_worker_by_node_id(self, node_id):
        """Return a `task_id -> worker_actor` dict for the given node.

        Raises:
            Exception: If the graph has no node with `node_id`.
        """
        for execution_node in self._execution_nodes:
            if execution_node.node_id == node_id:
                return {
                    task.task_id: task.worker_actor
                    for task in execution_node.execution_tasks
                }
        raise Exception("Node {} does not exist!".format(node_id))
+113
View File
@@ -0,0 +1,113 @@
import logging
from abc import ABC, abstractmethod
import ray.streaming.context as context
from ray.streaming import message
from ray.streaming.operator import OperatorType
logger = logging.getLogger(__name__)
class Processor(ABC):
    """The base interface for all processors."""

    @abstractmethod
    def open(self, collectors, runtime_context):
        """Receive the collectors and runtime context before processing."""

    @abstractmethod
    def process(self, record: message.Record):
        """Handle one record."""

    @abstractmethod
    def close(self):
        """Lifecycle hook; no behavior is defined at this level."""
class StreamingProcessor(Processor, ABC):
    """Processing unit that wraps a single operator.

    ``process`` is left to subclasses; ``close`` does nothing here.
    """

    def __init__(self, operator):
        self.operator = operator
        # Both are filled in by open().
        self.collectors = None
        self.runtime_context = None

    def open(self, collectors, runtime_context: context.RuntimeContext):
        """Remember the collectors/context and open the wrapped operator."""
        self.collectors = collectors
        self.runtime_context = runtime_context
        wrapped_operator = self.operator
        if wrapped_operator is not None:
            wrapped_operator.open(collectors, runtime_context)
        logger.info("Opened Processor {}".format(self))

    def close(self):
        return None
class SourceProcessor(StreamingProcessor):
    """Processor for :class:`ray.streaming.operator.SourceOperator` """

    def __init__(self, operator):
        super().__init__(operator)

    def process(self, record):
        # A source produces records through run(); it never consumes any.
        error = Exception("SourceProcessor should not process record")
        raise error

    def run(self):
        """Delegate to the wrapped operator's ``run``."""
        source_operator = self.operator
        source_operator.run()
class OneInputProcessor(StreamingProcessor):
    """Processor for stream operator with one input"""

    def __init__(self, operator):
        super().__init__(operator)

    def process(self, record):
        """Hand ``record`` to the wrapped operator's ``process_element``."""
        one_input_operator = self.operator
        one_input_operator.process_element(record)
class TwoInputProcessor(StreamingProcessor):
    """Processor for stream operator with two inputs"""

    def __init__(self, operator):
        super().__init__(operator)
        # Backing fields for the left_stream / right_stream properties.
        self._left_stream = None
        self._right_stream = None

    def process(self, record: message.Record):
        """Route ``record`` to the operator as the left or right element.

        A record whose ``stream`` equals ``left_stream`` is passed as the
        first argument of ``process_element``; any other record is passed
        as the second argument.
        """
        if record.stream == self.left_stream:
            self.operator.process_element(record, None)
        else:
            self.operator.process_element(None, record)

    # The original getters returned ``self.left_stream`` /
    # ``self.right_stream`` and the right setter assigned
    # ``self.right_stream``: each re-entered its own property and ended in
    # RecursionError. They now read and write the underscore-prefixed
    # backing fields.
    @property
    def left_stream(self):
        return self._left_stream

    @left_stream.setter
    def left_stream(self, value):
        self._left_stream = value

    @property
    def right_stream(self):
        return self._right_stream

    @right_stream.setter
    def right_stream(self, value):
        self._right_stream = value
def build_processor(operator_instance):
    """Create a processor for the given operator.

    The processor class is chosen from ``operator_instance.operator_type()``;
    an exception is raised for any type other than SOURCE, ONE_INPUT or
    TWO_INPUT.
    """
    kind = operator_instance.operator_type()
    logger.info(
        "Building StreamProcessor, operator type = {}, operator = {}.".format(
            kind, operator_instance))
    # Guard-clause form: each supported type returns immediately.
    if kind == OperatorType.SOURCE:
        return SourceProcessor(operator_instance)
    if kind == OperatorType.ONE_INPUT:
        return OneInputProcessor(operator_instance)
    if kind == OperatorType.TWO_INPUT:
        return TwoInputProcessor(operator_instance)
    raise Exception("Current operator type is not supported")
+158
View File
@@ -0,0 +1,158 @@
import logging
import pickle
import threading
from abc import ABC, abstractmethod
import ray
from ray.streaming.collector import OutputCollector
from ray.streaming.config import Config
from ray.streaming.context import RuntimeContextImpl
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader
logger = logging.getLogger(__name__)
class StreamTask(ABC):
    """Base class for all streaming tasks. Each task runs a processor.

    Construction wires the task up (``prepare_task``) and creates, but does
    not start, a daemon thread whose target is ``run``; ``start`` starts it.
    Subclasses implement ``init``, ``run`` and ``cancel_task``.
    """

    def __init__(self, task_id, processor, worker):
        self.task_id = task_id
        self.processor = processor
        self.worker = worker
        self.reader = None  # DataReader
        self.writers = {}  # ExecutionEdge -> DataWriter
        self.thread = None
        # Channels and the processor are set up before the thread exists.
        self.prepare_task()
        self.thread = threading.Thread(target=self.run, daemon=True)

    def prepare_task(self):
        """Create writers/reader for this task and open the processor.

        Builds the channel configuration from the worker config, creates one
        DataWriter (plus an OutputCollector) per output edge that has target
        tasks, creates a single DataReader over all input channels if there
        are any, and finally opens the processor with the collectors and a
        runtime context.
        """
        # Start from a copy of the worker config and fill in channel options.
        channel_conf = dict(self.worker.config)
        channel_size = int(
            self.worker.config.get(Config.CHANNEL_SIZE,
                                   Config.CHANNEL_SIZE_DEFAULT))
        channel_conf[Config.CHANNEL_SIZE] = channel_size
        channel_conf[Config.TASK_JOB_ID] = ray.runtime_context.\
            _get_runtime_context().current_driver_id
        channel_conf[Config.CHANNEL_TYPE] = self.worker.config \
            .get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL)

        execution_graph = self.worker.execution_graph
        execution_node = self.worker.execution_node
        # writers
        collectors = []
        for edge in execution_node.output_edges:
            # channel id -> actor of the downstream task, for this edge only
            output_actor_ids = {}
            task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
                edge.target_node_id)
            for target_task_id, target_actor in task_id2_worker.items():
                channel_name = ChannelID.gen_id(self.task_id, target_task_id,
                                                execution_graph.build_time())
                output_actor_ids[channel_name] = target_actor
            if len(output_actor_ids) > 0:
                channel_ids = list(output_actor_ids.keys())
                to_actor_ids = list(output_actor_ids.values())
                writer = DataWriter(channel_ids, to_actor_ids, channel_conf)
                logger.info("Create DataWriter succeed.")
                self.writers[edge] = writer
                collectors.append(
                    OutputCollector(channel_ids, writer, edge.partition))

        # readers
        # channel id -> actor of the upstream task, across all input edges
        input_actor_ids = {}
        for edge in execution_node.input_edges:
            task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
                edge.src_node_id)
            for src_task_id, src_actor in task_id2_worker.items():
                channel_name = ChannelID.gen_id(src_task_id, self.task_id,
                                                execution_graph.build_time())
                input_actor_ids[channel_name] = src_actor
        if len(input_actor_ids) > 0:
            channel_ids = list(input_actor_ids.keys())
            from_actor_ids = list(input_actor_ids.values())
            logger.info("Create DataReader, channels {}.".format(channel_ids))
            self.reader = DataReader(channel_ids, from_actor_ids, channel_conf)

            def exit_handler():
                # Make DataReader stop read data when MockQueue destructor
                # gets called to avoid crash
                self.cancel_task()

            import atexit
            atexit.register(exit_handler)

        runtime_context = RuntimeContextImpl(
            self.worker.execution_task.task_id,
            self.worker.execution_task.task_index, execution_node.parallelism)
        logger.info("open Processor {}".format(self.processor))
        self.processor.open(collectors, runtime_context)

    @abstractmethod
    def init(self):
        pass

    def start(self):
        """Start the thread created in ``__init__``."""
        self.thread.start()

    @abstractmethod
    def run(self):
        # Thread target; implemented by subclasses.
        pass

    @abstractmethod
    def cancel_task(self):
        # Called by the atexit handler registered in prepare_task.
        pass
class InputStreamTask(StreamTask):
    """Base class for stream tasks that execute a
    :class:`runtime.processor.OneInputProcessor` or
    :class:`runtime.processor.TwoInputProcessor` """

    def __init__(self, task_id, processor_instance, worker):
        super().__init__(task_id, processor_instance, worker)
        # ``running`` is cleared by cancel_task() to ask the loop to exit;
        # ``stopped`` is set by run() once the loop has exited.
        self.running = True
        self.stopped = False
        self.read_timeout_millis = \
            int(worker.config.get(Config.READ_TIMEOUT_MS,
                                  Config.DEFAULT_READ_TIMEOUT_MS))

    def init(self):
        pass

    def run(self):
        """Read items from the reader and feed them to the processor.

        Loops until ``running`` becomes False. ``stopped`` is set even when
        reading or processing raises, so cancel_task() cannot wait forever
        on a loop that has already died.
        """
        try:
            while self.running:
                item = self.reader.read(self.read_timeout_millis)
                if item is not None:
                    msg_data = item.body()
                    # NOTE(review): pickle.loads runs on bytes received from
                    # the channel; only safe while every upstream writer is
                    # trusted.
                    msg = pickle.loads(msg_data)
                    self.processor.process(msg)
        finally:
            self.stopped = True

    def cancel_task(self):
        """Ask the read loop to stop and wait until it has.

        Polls with a short sleep instead of a hot ``pass`` loop, and returns
        straight away when the task thread is not alive (never started, or
        already finished), where the original spun forever.
        """
        import time

        self.running = False
        task_thread = self.thread
        while not self.stopped:
            if task_thread is None or not task_thread.is_alive():
                break
            time.sleep(0.001)
class OneInputStreamTask(InputStreamTask):
    """A stream task for executing :class:`runtime.processor.OneInputProcessor`
    """

    def __init__(self, task_id, processor_instance, worker):
        # Nothing extra to set up; everything is inherited.
        InputStreamTask.__init__(self, task_id, processor_instance, worker)
class SourceStreamTask(StreamTask):
    """A stream task for executing :class:`runtime.processor.SourceProcessor`
    """

    def __init__(self, task_id, processor_instance, worker):
        super().__init__(task_id, processor_instance, worker)

    def init(self):
        return None

    def run(self):
        """Thread target: run the source processor."""
        source_processor = self.processor
        source_processor.run()

    def cancel_task(self):
        # Nothing to cancel for a source task.
        return None
+104
View File
@@ -0,0 +1,104 @@
import logging
import ray
import ray.streaming._streaming as _streaming
import ray.streaming.generated.remote_call_pb2 as remote_call_pb
import ray.streaming.runtime.processor as processor
from ray._raylet import PythonFunctionDescriptor
from ray.streaming.config import Config
from ray.streaming.runtime.graph import ExecutionGraph
from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask
logger = logging.getLogger(__name__)
@ray.remote
class JobWorker(object):
    """A streaming job worker is used to execute user-defined function and
    interact with `JobMaster`"""

    def __init__(self):
        # Everything below is populated by init().
        self.worker_context = None
        self.task_id = None
        self.config = None
        self.execution_graph = None
        self.execution_task = None
        self.execution_node = None
        self.stream_processor = None
        self.task = None
        self.reader_client = None
        self.writer_client = None

    def init(self, worker_context_bytes):
        """Initialize the worker from a serialized ``WorkerContext``.

        Parses the context, builds the execution graph, looks up this
        worker's task and node, builds the processor for the node's
        operator, creates the reader/writer clients, then creates and
        starts the stream task.

        Args:
            worker_context_bytes: serialized
                ``remote_call_pb.WorkerContext`` message.

        Returns:
            True once the task has been started.
        """
        worker_context = remote_call_pb.WorkerContext()
        worker_context.ParseFromString(worker_context_bytes)
        self.worker_context = worker_context
        self.task_id = worker_context.task_id
        self.config = worker_context.conf
        execution_graph = ExecutionGraph(worker_context.graph)
        self.execution_graph = execution_graph
        self.execution_task = self.execution_graph. \
            get_execution_task_by_task_id(self.task_id)
        self.execution_node = self.execution_graph. \
            get_execution_node_by_task_id(self.task_id)
        operator = self.execution_node.stream_operator
        self.stream_processor = processor.build_processor(operator)
        logger.info(
            "Initializing JobWorker, task_id: {}, operator: {}.".format(
                self.task_id, self.stream_processor))

        # NOTE(review): this tests the truthiness of the configured channel
        # type rather than comparing it with Config.NATIVE_CHANNEL, so any
        # non-empty value takes this branch. Left as-is; confirm intent.
        if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL):
            core_worker = ray.worker.global_worker.core_worker
            reader_async_func = PythonFunctionDescriptor(
                __name__, self.on_reader_message.__name__,
                self.__class__.__name__)
            reader_sync_func = PythonFunctionDescriptor(
                __name__, self.on_reader_message_sync.__name__,
                self.__class__.__name__)
            self.reader_client = _streaming.ReaderClient(
                core_worker, reader_async_func, reader_sync_func)
            writer_async_func = PythonFunctionDescriptor(
                __name__, self.on_writer_message.__name__,
                self.__class__.__name__)
            writer_sync_func = PythonFunctionDescriptor(
                __name__, self.on_writer_message_sync.__name__,
                self.__class__.__name__)
            self.writer_client = _streaming.WriterClient(
                core_worker, writer_async_func, writer_sync_func)

        self.task = self.create_stream_task()
        self.task.start()
        logger.info("JobWorker init succeed")
        return True

    def create_stream_task(self):
        """Create the stream task matching the type of the processor.

        Raises:
            Exception: if the processor is neither a SourceProcessor nor a
                OneInputProcessor.
        """
        if isinstance(self.stream_processor, processor.SourceProcessor):
            return SourceStreamTask(self.task_id, self.stream_processor, self)
        elif isinstance(self.stream_processor, processor.OneInputProcessor):
            return OneInputStreamTask(self.task_id, self.stream_processor,
                                      self)
        else:
            # The original concatenated a str with a type object, which
            # raised TypeError instead of this exception.
            raise Exception("Unsupported processor type: {}".format(
                type(self.stream_processor)))

    def on_reader_message(self, buffer: bytes):
        """used in direct call mode"""
        self.reader_client.on_reader_message(buffer)

    def on_reader_message_sync(self, buffer: bytes):
        """used in direct call mode"""
        if self.reader_client is None:
            return b" " * 4  # special flag to indicate this actor not ready
        result = self.reader_client.on_reader_message_sync(buffer)
        return result.to_pybytes()

    def on_writer_message(self, buffer: bytes):
        """used in direct call mode"""
        self.writer_client.on_writer_message(buffer)

    def on_writer_message_sync(self, buffer: bytes):
        """used in direct call mode"""
        if self.writer_client is None:
            return b" " * 4  # special flag to indicate this actor not ready
        result = self.writer_client.on_writer_message_sync(buffer)
        return result.to_pybytes()
-689
View File
@@ -1,689 +0,0 @@
import logging
import pickle
import sys
import time
import networkx as nx
import ray
import ray.streaming.processor as processor
import ray.streaming.runtime.transfer as transfer
from ray.streaming.communication import DataChannel
from ray.streaming.config import Config
from ray.streaming.jobworker import JobWorker
from ray.streaming.operator import Operator, OpType
from ray.streaming.operator import PScheme, PStrategy
logger = logging.getLogger(__name__)
logger.setLevel("INFO")
# Rolling sum's logic
def _sum(value_1, value_2):
return value_1 + value_2
# Partitioning strategies that require all-to-all instance communication.
# ExecutionGraph._generate_channels creates one channel for every
# (source instance, destination instance) pair for these strategies.
all_to_all_strategies = [
    PStrategy.Shuffle, PStrategy.ShuffleByKey, PStrategy.Broadcast,
    PStrategy.RoundRobin
]
# Environment configuration
class Conf:
    """Configuration of the streaming environment.

    Attributes:
        parallelism: value of the ``parallelism`` argument (default 1).
        channel_type: value of the ``channel_type`` argument
            (default ``Config.MEMORY_CHANNEL``).
    """

    def __init__(self, parallelism=1, channel_type=Config.MEMORY_CHANNEL):
        self.parallelism, self.channel_type = parallelism, channel_type
        # ...
class ExecutionGraph:
    """Physical dataflow of an environment: channels, task ids and actors.

    ``build_graph`` first builds all data channels (``build_channels``) and
    then creates one ``JobWorker`` actor per operator instance, visiting the
    logical topology in topological order.
    """

    def __init__(self, env):
        self.env = env
        self.physical_topo = nx.DiGraph()  # DAG
        # Handles to all actors in the physical dataflow
        self.actor_handles = []
        # (op_id, op_instance_index) -> ActorID
        self.actors_map = {}
        # execution graph build time: milliseconds since epoch
        self.build_time = 0
        self.task_id_counter = 0
        # (op_id, op_instance_index) -> task id, filled by build_channels
        self.task_ids = {}
        self.input_channels = {}  # operator id -> input channels
        self.output_channels = {}  # operator id -> output channels

    # Constructs and deploys a Ray actor of a specific type
    # TODO (john): Actor placement information should be specified in
    # the environment's configuration
    def __generate_actor(self, instance_index, operator, input_channels,
                         output_channels):
        """Generates an actor that will execute a particular instance of
        the logical operator

        Attributes:
            instance_index: The index of the instance the actor will execute.
            operator: The metadata of the logical operator.
            input_channels: The input channels of the instance.
            output_channels The output channels of the instance.
        """
        worker_id = (operator.id, instance_index)
        # Record the physical dataflow graph (for debugging purposes)
        self.__add_channel(worker_id, output_channels)
        # Note direct_call only support pass by value
        return JobWorker._remote(
            args=[worker_id, operator, input_channels, output_channels],
            is_direct_call=True)

    # Constructs and deploys a Ray actor for each instance of
    # the given operator
    def __generate_actors(self, operator, upstream_channels,
                          downstream_channels):
        """Generates one actor for each instance of the given logical
        operator.

        Attributes:
            operator (Operator): The logical operator metadata.
            upstream_channels (list): A list of all upstream channels for
                all instances of the operator.
            downstream_channels (list): A list of all downstream channels
                for all instances of the operator.
        """
        num_instances = operator.num_instances
        logger.info("Generating {} actors of type {}...".format(
            num_instances, operator.type))
        handles = []
        for i in range(num_instances):
            # Collect input and output channels for the particular instance
            ip = [c for c in upstream_channels if c.dst_instance_index == i]
            op = [c for c in downstream_channels if c.src_instance_index == i]
            log = "Constructed {} input and {} output channels "
            log += "for the {}-th instance of the {} operator."
            logger.debug(log.format(len(ip), len(op), i, operator.type))
            handle = self.__generate_actor(i, operator, ip, op)
            if handle:
                handles.append(handle)
                self.actors_map[(operator.id, i)] = handle
        return handles

    # Adds a channel/edge to the physical dataflow graph
    def __add_channel(self, actor_id, output_channels):
        for c in output_channels:
            dest_actor_id = (c.dst_operator_id, c.dst_instance_index)
            self.physical_topo.add_edge(actor_id, dest_actor_id)

    # Generates all required data channels between an operator
    # and its downstream operators
    def _generate_channels(self, operator):
        """Generates all output data channels
        (see: DataChannel in communication.py) for all instances of
        the given logical operator.

        The function constructs one data channel for each pair of
        communicating operator instances (instance_1,instance_2),
        where instance_1 is an instance of the given operator and instance_2
        is an instance of a direct downstream operator.

        The number of total channels generated depends on the partitioning
        strategy specified by the user.
        """
        channels = {}  # destination operator id -> channels
        strategies = operator.partitioning_strategies
        for dst_operator, p_scheme in strategies.items():
            num_dest_instances = self.env.operators[dst_operator].num_instances
            entry = channels.setdefault(dst_operator, [])
            if p_scheme.strategy == PStrategy.Forward:
                # Forward: one channel per source instance, with destination
                # index i modulo the number of destination instances.
                for i in range(operator.num_instances):
                    # ID of destination instance to connect
                    id = i % num_dest_instances
                    qid = self._gen_str_qid(operator.id, i, dst_operator, id)
                    c = DataChannel(operator.id, i, dst_operator, id, qid)
                    entry.append(c)
            elif p_scheme.strategy in all_to_all_strategies:
                # All-to-all: one channel per (source, destination) pair.
                for i in range(operator.num_instances):
                    for j in range(num_dest_instances):
                        qid = self._gen_str_qid(operator.id, i, dst_operator,
                                                j)
                        c = DataChannel(operator.id, i, dst_operator, j, qid)
                        entry.append(c)
            else:
                # TODO (john): Add support for other partitioning strategies
                # NOTE(review): sys.exit raises SystemExit from library code;
                # consider raising a regular exception instead.
                sys.exit("Unrecognized or unsupported partitioning strategy.")
        return channels

    def _gen_str_qid(self, src_operator_id, src_instance_index,
                     dst_operator_id, dst_instance_index):
        # Channel id derived from the two task ids and the graph build time.
        from_task_id = self.env.execution_graph.get_task_id(
            src_operator_id, src_instance_index)
        to_task_id = self.env.execution_graph.get_task_id(
            dst_operator_id, dst_instance_index)
        return transfer.ChannelID.gen_id(from_task_id, to_task_id,
                                         self.build_time)

    def _gen_task_id(self):
        # Returns the current counter value, then increments the counter.
        task_id = self.task_id_counter
        self.task_id_counter += 1
        return task_id

    def get_task_id(self, op_id, op_instance_id):
        return self.task_ids[(op_id, op_instance_id)]

    def get_actor(self, op_id, op_instance_id):
        return self.actors_map[(op_id, op_instance_id)]

    # Prints the physical dataflow graph
    def print_physical_graph(self):
        logger.info("===================================")
        logger.info("======Physical Dataflow Graph======")
        logger.info("===================================")
        # Print all data channels between operator instances
        log = "(Source Operator ID,Source Operator Name,Source Instance ID)"
        log += " --> "
        log += "(Destination Operator ID,Destination Operator Name,"
        log += "Destination Instance ID)"
        logger.info(log)
        for src_actor_id, dst_actor_id in self.physical_topo.edges:
            src_operator_id, src_instance_index = src_actor_id
            dst_operator_id, dst_instance_index = dst_actor_id
            logger.info("({},{},{}) --> ({},{},{})".format(
                src_operator_id, self.env.operators[src_operator_id].name,
                src_instance_index, dst_operator_id,
                self.env.operators[dst_operator_id].name, dst_instance_index))

    def build_graph(self):
        """Build all channels, then create the actors for every operator."""
        self.build_channels()

        # to support cyclic reference serialization
        try:
            ray.register_custom_serializer(Environment, use_pickle=True)
            ray.register_custom_serializer(ExecutionGraph, use_pickle=True)
            ray.register_custom_serializer(OpType, use_pickle=True)
            ray.register_custom_serializer(PStrategy, use_pickle=True)
        except Exception:
            # local mode can't use pickle
            pass

        # Each operator instance is implemented as a Ray actor
        # Actors are deployed in topological order, as we traverse the
        # logical dataflow from sources to sinks.
        for node in nx.topological_sort(self.env.logical_topo):
            operator = self.env.operators[node]
            # Instantiate Ray actors
            handles = self.__generate_actors(
                operator, self.input_channels.get(node, []),
                self.output_channels.get(node, []))
            if handles:
                self.actor_handles.extend(handles)

    def build_channels(self):
        """Assign task ids and build the per-operator channel maps.

        Sets ``build_time``, gives every operator instance a task id, creates
        the downstream channels of every operator, and stores them grouped
        by destination operator (``input_channels``) and by source operator
        (``output_channels``).
        """
        self.build_time = int(time.time() * 1000)
        # gen auto-incremented unique task id for every operator instance
        for node in nx.topological_sort(self.env.logical_topo):
            operator = self.env.operators[node]
            for i in range(operator.num_instances):
                operator_instance_id = (operator.id, i)
                self.task_ids[operator_instance_id] = self._gen_task_id()
        channels = {}
        for node in nx.topological_sort(self.env.logical_topo):
            operator = self.env.operators[node]
            # Generate downstream data channels
            downstream_channels = self._generate_channels(operator)
            channels[node] = downstream_channels
        # op_id -> channels
        input_channels = {}
        output_channels = {}
        for op_id, all_downstream_channels in channels.items():
            for dst_op_channels in all_downstream_channels.values():
                for c in dst_op_channels:
                    dst = input_channels.setdefault(c.dst_operator_id, [])
                    dst.append(c)
                    src = output_channels.setdefault(c.src_operator_id, [])
                    src.append(c)
        self.input_channels = input_channels
        self.output_channels = output_channels
# The execution environment for a streaming job
class Environment:
    """A streaming environment.

    This class is responsible for constructing the logical dataflow and
    for triggering construction of the physical dataflow.

    Attributes:
        logical_topo (DiGraph): The user-defined logical topology in
            NetworkX DiGraph format.
            (See: https://networkx.github.io)
        operators (dict): A mapping from operator ids to operator metadata
            (See: Operator in operator.py).
        config (Conf): The environment's configuration.
        topo_cleaned (bool): A flag that indicates whether the logical
            topology is garbage collected (True) or not (False).
        operator_id_counter (int): Next operator id handed out by
            gen_operator_id.
        execution_graph (ExecutionGraph): The physical dataflow; None until
            execute() is called.
    """

    def __init__(self, config=None):
        # A ``Conf()`` default argument would be created once and shared by
        # every Environment, so set_parallelism() on one environment would
        # leak into all others. Build a fresh Conf per instance instead.
        if config is None:
            config = Conf()
        self.logical_topo = nx.DiGraph()  # DAG
        self.operators = {}  # operator id --> operator object
        self.config = config  # Environment's configuration
        self.topo_cleaned = False
        self.operator_id_counter = 0
        self.execution_graph = None  # set when executed

    def gen_operator_id(self):
        """Return the next operator id and advance the counter."""
        op_id = self.operator_id_counter
        self.operator_id_counter += 1
        return op_id

    # An edge denotes a flow of data between logical operators
    # and may correspond to multiple data channels in the physical dataflow
    def _add_edge(self, source, destination):
        self.logical_topo.add_edge(source, destination)

    # Cleans the logical dataflow graph to construct and
    # deploy the physical dataflow
    def _collect_garbage(self):
        if self.topo_cleaned is True:
            return
        for node in self.logical_topo:
            self.operators[node]._clean()
        self.topo_cleaned = True

    # Sets the level of parallelism for a registered operator
    # Overwrites the environment parallelism (if set)
    def _set_parallelism(self, operator_id, level_of_parallelism):
        self.operators[operator_id].num_instances = level_of_parallelism

    # Sets the same level of parallelism for all operators in the environment
    def set_parallelism(self, parallelism):
        self.config.parallelism = parallelism

    # Creates and registers a user-defined data source
    # TODO (john): There should be different types of sources, e.g. sources
    # reading from Kafka, text files, etc.
    # TODO (john): Handle case where environment parallelism is set
    def source(self, source):
        """Register a source operator and return its output stream."""
        source_id = self.gen_operator_id()
        source_stream = DataStream(self, source_id)
        self.operators[source_id] = Operator(
            source_id, OpType.Source, processor.Source, "Source", logic=source)
        return source_stream

    # Creates and registers a new data source that reads a
    # text file line by line
    # TODO (john): There should be different types of sources,
    # e.g. sources reading from Kafka, text files, etc.
    # TODO (john): Handle case where environment parallelism is set
    def read_text_file(self, filepath):
        """Register a text-file source and return its output stream."""
        source_id = self.gen_operator_id()
        source_stream = DataStream(self, source_id)
        self.operators[source_id] = Operator(
            source_id,
            OpType.ReadTextFile,
            processor.ReadTextFile,
            "Read Text File",
            other=filepath)
        return source_stream

    # Constructs and deploys the physical dataflow
    def execute(self):
        """Deploys and executes the physical dataflow.

        Returns:
            The list of results of ``start.remote()`` for every actor.
        """
        self._collect_garbage()  # Make sure everything is clean
        # TODO (john): Check if dataflow has any 'logical inconsistencies'
        # For example, if there is a forward partitioning strategy but
        # the number of downstream instances is larger than the number of
        # upstream instances, some of the downstream instances will not be
        # used at all

        self.execution_graph = ExecutionGraph(self)
        self.execution_graph.build_graph()
        logger.info("init...")
        # init
        init_waits = []
        for actor_handle in self.execution_graph.actor_handles:
            init_waits.append(actor_handle.init.remote(pickle.dumps(self)))
        for wait in init_waits:
            assert ray.get(wait) is True
        logger.info("running...")
        # start
        exec_handles = []
        for actor_handle in self.execution_graph.actor_handles:
            exec_handles.append(actor_handle.start.remote())

        return exec_handles

    def wait_finish(self):
        """Poll each actor once per second until it reports finished."""
        for actor_handle in self.execution_graph.actor_handles:
            while not ray.get(actor_handle.is_finished.remote()):
                time.sleep(1)

    # Prints the logical dataflow graph
    def print_logical_graph(self):
        self._collect_garbage()
        logger.info("==================================")
        logger.info("======Logical Dataflow Graph======")
        logger.info("==================================")
        # Print operators in topological order
        for node in nx.topological_sort(self.logical_topo):
            downstream_neighbors = list(self.logical_topo.neighbors(node))
            logger.info("======Current Operator======")
            operator = self.operators[node]
            operator.print()
            logger.info("======Downstream Operators======")
            if len(downstream_neighbors) == 0:
                logger.info("None\n")
            for downstream_node in downstream_neighbors:
                self.operators[downstream_node].print()
# TODO (john): We also need KeyedDataStream and WindowedDataStream as
# subclasses of DataStream to prevent ill-defined logical dataflows
# A DataStream corresponds to an edge in the logical dataflow
class DataStream:
"""A data stream.
This class contains all information about a logical stream, i.e. an edge
in the logical topology. It is the main class exposed to the user.
Attributes:
id (UUID): The id of the stream
env (Environment): The environment the stream belongs to.
src_operator_id (UUID): The id of the source operator of the stream.
dst_operator_id (UUID): The id of the destination operator of the
stream.
is_partitioned (bool): Denotes if there is a partitioning strategy
(e.g. shuffle) for the stream or not (default stategy: Forward).
"""
stream_id_counter = 0
def __init__(self,
environment,
source_id=None,
dest_id=None,
is_partitioned=False):
self.env = environment
self.id = DataStream.stream_id_counter
DataStream.stream_id_counter += 1
self.src_operator_id = source_id
self.dst_operator_id = dest_id
# True if a partitioning strategy for this stream exists,
# false otherwise
self.is_partitioned = is_partitioned
# Generates a new stream after a data transformation is applied
def __expand(self):
stream = DataStream(self.env)
assert (self.dst_operator_id is not None)
stream.src_operator_id = self.dst_operator_id
stream.dst_operator_id = None
return stream
# Assigns the partitioning strategy to a new 'open-ended' stream
# and returns the stream. At this point, the partitioning strategy
# is not associated with any destination operator. We expect this to
# be done later, as we continue assembling the dataflow graph
def __partition(self, strategy, partition_fn=None):
scheme = PScheme(strategy, partition_fn)
source_operator = self.env.operators[self.src_operator_id]
new_stream = DataStream(
self.env, source_id=source_operator.id, is_partitioned=True)
source_operator._set_partition_strategy(new_stream.id, scheme)
return new_stream
# Registers the operator to the environment and returns a new
# 'open-ended' stream. The registered operator serves as the destination
# of the previously 'open' stream
def __register(self, operator):
"""Registers the given logical operator to the environment and
connects it to its upstream operator (if any).
A call to this function adds a new edge to the logical topology.
Attributes:
operator (Operator): The metadata of the logical operator.
"""
self.env.operators[operator.id] = operator
self.dst_operator_id = operator.id
logger.debug("Adding new dataflow edge ({},{}) --> ({},{})".format(
self.src_operator_id,
self.env.operators[self.src_operator_id].name,
self.dst_operator_id,
self.env.operators[self.dst_operator_id].name))
# Update logical dataflow graphs
self.env._add_edge(self.src_operator_id, self.dst_operator_id)
# Keep track of the partitioning strategy and the destination operator
src_operator = self.env.operators[self.src_operator_id]
if self.is_partitioned is True:
partitioning, _ = src_operator._get_partition_strategy(self.id)
src_operator._set_partition_strategy(self.id, partitioning,
operator.id)
elif src_operator.type == OpType.KeyBy:
# Set the output partitioning strategy to shuffle by key
partitioning = PScheme(PStrategy.ShuffleByKey)
src_operator._set_partition_strategy(self.id, partitioning,
operator.id)
else: # No partitioning strategy has been defined - set default
partitioning = PScheme(PStrategy.Forward)
src_operator._set_partition_strategy(self.id, partitioning,
operator.id)
return self.__expand()
# Sets the level of parallelism for an operator, i.e. its total
# number of instances. Each operator instance corresponds to an actor
# in the physical dataflow
def set_parallelism(self, num_instances):
"""Sets the number of instances for the source operator of the stream.
Attributes:
num_instances (int): The level of parallelism for the source
operator of the stream.
"""
assert (num_instances > 0)
self.env._set_parallelism(self.src_operator_id, num_instances)
return self
# Stream Partitioning Strategies #
# TODO (john): Currently, only forward (default), shuffle,
# and broadcast are supported
# Hash-based record shuffling
def shuffle(self):
"""Registers a shuffling partitioning strategy for the stream."""
return self.__partition(PStrategy.Shuffle)
# Broadcasts each record to all downstream instances
def broadcast(self):
"""Registers a broadcast partitioning strategy for the stream."""
return self.__partition(PStrategy.Broadcast)
# Rescales load to downstream instances
def rescale(self):
"""Registers a rescale partitioning strategy for the stream.
Same as Flink's rescale (see: https://ci.apache.org/projects/flink/
flink-docs-stable/dev/stream/operators/#physical-partitioning).
"""
return self.__partition(PStrategy.Rescale)
# Round-robin partitioning
def round_robin(self):
"""Registers a round-robin partitioning strategy for the stream."""
return self.__partition(PStrategy.RoundRobin)
# User-defined partitioning
def partition(self, partition_fn):
"""Registers a user-defined partitioning strategy for the stream.
Attributes:
partition_fn (function): The user-defined partitioning function.
"""
return self.__partition(PStrategy.Custom, partition_fn)
# Data Trasnformations #
# TODO (john): Expand set of supported operators.
# TODO (john): To support event-time windows we need a mechanism for
# generating and processing watermarks
# Registers map operator to the environment
def map(self, map_fn, name="Map"):
"""Applies a map operator to the stream.
Attributes:
map_fn (function): The user-defined logic of the map.
"""
op = Operator(
self.env.gen_operator_id(),
OpType.Map,
processor.Map,
name,
map_fn,
num_instances=self.env.config.parallelism)
return self.__register(op)
# Registers flatmap operator to the environment
def flat_map(self, flatmap_fn):
"""Applies a flatmap operator to the stream.
Attributes:
flatmap_fn (function): The user-defined logic of the flatmap
(e.g. split()).
"""
op = Operator(
self.env.gen_operator_id(),
OpType.FlatMap,
processor.FlatMap,
"FlatMap",
flatmap_fn,
num_instances=self.env.config.parallelism)
return self.__register(op)
# Registers keyBy operator to the environment
# TODO (john): This should returned a KeyedDataStream
def key_by(self, key_selector):
"""Applies a key_by operator to the stream.
Attributes:
key_attribute_index (int): The index of the key attributed
(assuming tuple records).
"""
op = Operator(
self.env.gen_operator_id(),
OpType.KeyBy,
processor.KeyBy,
"KeyBy",
other=key_selector,
num_instances=self.env.config.parallelism)
return self.__register(op)
# Registers Reduce operator to the environment
def reduce(self, reduce_fn):
"""Applies a rolling sum operator to the stream.
Attributes:
sum_attribute_index (int): The index of the attribute to sum
(assuming tuple records).
"""
op = Operator(
self.env.gen_operator_id(),
OpType.Reduce,
processor.Reduce,
"Sum",
reduce_fn,
num_instances=self.env.config.parallelism)
return self.__register(op)
# Registers Sum operator to the environment
def sum(self, attribute_selector, state_keeper=None):
"""Applies a rolling sum operator to the stream.
Attributes:
sum_attribute_index (int): The index of the attribute to sum
(assuming tuple records).
"""
op = Operator(
self.env.gen_operator_id(),
OpType.Sum,
processor.Reduce,
"Sum",
_sum,
other=attribute_selector,
state_actor=state_keeper,
num_instances=self.env.config.parallelism)
return self.__register(op)
# Registers window operator to the environment.
# This is a system time window
# TODO (john): This should return a WindowedDataStream
def time_window(self, window_width_ms):
"""Applies a system time window to the stream.
Attributes:
window_width_ms (int): The length of the window in ms.
"""
raise Exception("time_window is unsupported")
# Registers filter operator to the environment
def filter(self, filter_fn):
"""Applies a filter to the stream.
Attributes:
filter_fn (function): The user-defined filter function.
"""
op = Operator(
self.env.gen_operator_id(),
OpType.Filter,
processor.Filter,
"Filter",
filter_fn,
num_instances=self.env.config.parallelism)
return self.__register(op)
# TODO (john): Registers window join operator to the environment
def window_join(self, other_stream, join_attribute, window_width):
    """Registers a WindowJoin operator for this stream.

    NOTE: ``other_stream``, ``join_attribute`` and ``window_width`` are
    accepted but not forwarded to the operator yet.
    """
    environment = self.env
    join_operator = Operator(
        environment.gen_operator_id(),
        OpType.WindowJoin,
        processor.WindowJoin,
        "WindowJoin",
        num_instances=environment.config.parallelism)
    registered = self.__register(join_operator)
    return registered
# Registers inspect operator to the environment
def inspect(self, inspect_logic):
    """Registers an Inspect operator to look at the stream's content.

    Attributes:
         inspect_logic (function): The user-defined inspect function.
    """
    environment = self.env
    inspect_operator = Operator(
        environment.gen_operator_id(),
        OpType.Inspect,
        processor.Inspect,
        "Inspect",
        inspect_logic,
        num_instances=environment.config.parallelism)
    registered = self.__register(inspect_operator)
    return registered
# Registers sink operator to the environment
# TODO (john): A sink now just drops records but it should be able to
# export data to other systems
def sink(self):
    """Terminates the stream by registering a Sink operator."""
    environment = self.env
    sink_operator = Operator(
        environment.gen_operator_id(),
        OpType.Sink,
        processor.Sink,
        "Sink",
        num_instances=environment.config.parallelism)
    registered = self.__register(sink_operator)
    return registered
+22
View File
@@ -0,0 +1,22 @@
from ray.streaming import function
from ray.streaming.runtime import gateway_client
def test_get_simple_function_class():
    """MapFunction must resolve to its SimpleMapFunction wrapper class."""
    resolved_class = function._get_simple_function_class(function.MapFunction)
    expected_class = function.SimpleMapFunction
    assert resolved_class is expected_class
class MapFunc(function.MapFunction):
    """Map function used by the loading test; stringifies each value."""

    def map(self, value):
        """Return ``value`` converted with ``str``."""
        text = str(value)
        return text
def test_load_function():
    """A serialized descriptor naming MapFunc must load a MapFunc."""
    # Descriptor layout: function_bytes, module_name, class_name,
    # function_name, function_interface
    descriptor = [None, __name__, MapFunc.__name__, None, "MapFunction"]
    serialized_descriptor = gateway_client.serialize(descriptor)
    loaded = function.load_function(serialized_descriptor)
    assert type(loaded) is MapFunc
@@ -1,206 +0,0 @@
from ray.streaming.streaming import Environment, ExecutionGraph
from ray.streaming.operator import OpType, PStrategy
def test_parallelism():
    """Checks the number of instances assigned to each operator."""
    env = Environment()
    # One common parallelism for every operator in the pipeline
    env.set_parallelism(2)
    stream = env.source(None).map(None).filter(None).flat_map(None)
    env._collect_garbage()
    for op in env.operators.values():
        # TODO (john): Currently each source has only one instance
        expected = 1 if op.type == OpType.Source else 2
        assert op.num_instances == expected, (op.num_instances, expected)
    # Extend the pipeline with operators that override the parallelism
    first_map = stream.map(None, "Map1").shuffle().set_parallelism(3)
    first_map.map(None, "Map2").set_parallelism(4)
    env._collect_garbage()
    for op in env.operators.values():
        if op.type == OpType.Source:
            expected = 1
        elif op.name == "Map1":
            expected = 3
        elif op.name == "Map2":
            expected = 4
        else:
            # Operators created before the overrides keep the default
            expected = 2
        assert op.num_instances == expected, (op.num_instances, expected)
def test_partitioning():
    """Checks that only the last partitioning strategy set is kept."""
    env = Environment()
    # Stack several partitioning strategies on the same stream
    env.source(None).shuffle().rescale().broadcast().map(
        None).broadcast().shuffle()
    env._collect_garbage()
    for op in env.operators.values():
        for scheme in op.partitioning_strategies.values():
            # The source was last set to broadcast, the map to shuffle
            if op.type == OpType.Source:
                expected = PStrategy.Broadcast
            else:
                expected = PStrategy.Shuffle
            assert scheme.strategy == expected, (scheme.strategy, expected)
def test_forking():
    """Checks that forking a stream yields two correctly wired branches."""
    env = Environment()
    # Shared prefix of both branches: source -> map
    forked = env.source(None).map(None).set_parallelism(2)
    # Branch one: explicit shuffle, keyed on 0, summed on 1
    forked.shuffle().key_by(0).sum(1)
    # Branch two: default partitioning, keyed on 1, summed on 2
    forked.key_by(1).sum(2)
    env._collect_garbage()

    # Ids of the operators of interest, filled in by the scan below
    source_id = map_id = None
    keyby1_id = keyby2_id = None
    sum1_id = sum2_id = None
    for op_id, op in env.operators.items():
        if op.type == OpType.Source:
            source_id = op_id
        elif op.type == OpType.Map:
            map_id = op_id
        elif op.type == OpType.KeyBy:
            # The key argument tells the two key_by operators apart
            if op.other_args == 0:
                keyby1_id = op_id
            else:
                assert op.other_args == 1, (op.other_args, 1)
                keyby2_id = op_id
        elif op.type == OpType.Sum:
            # Likewise the sum argument identifies the branch
            if op.other_args == 1:
                sum1_id = op_id
            else:
                assert op.other_args == 2, (op.other_args, 2)
                sum2_id = op_id

    # Walk the logical edges and verify targets and partitioning
    for src, dst in env.logical_topo.edges:
        src_op = env.operators[src]
        if src == source_id:
            assert dst == map_id, (dst, map_id)
            continue
        if src == map_id:
            strategy = src_op.partitioning_strategies[dst].strategy
            key_index = env.operators[dst].other_args
            if key_index == 0:  # First branch
                assert strategy == PStrategy.Shuffle, (strategy,
                                                       PStrategy.Shuffle)
                assert dst == keyby1_id, (dst, keyby1_id)
            else:  # Second branch
                assert key_index == 1, (key_index, 1)
                assert strategy == PStrategy.Forward, (strategy,
                                                       PStrategy.Forward)
                assert dst == keyby2_id, (dst, keyby2_id)
            continue
        if src == keyby1_id or src == keyby2_id:
            strategy = src_op.partitioning_strategies[dst].strategy
            key_index = env.operators[dst].other_args
            if key_index == 1:  # First branch
                expected_sum_id = sum1_id
            else:  # Second branch
                assert key_index == 2, (key_index, 2)
                expected_sum_id = sum2_id
            # Both branches shuffle by key into their sum operator
            assert strategy == PStrategy.ShuffleByKey, (
                strategy, PStrategy.ShuffleByKey)
            assert dst == expected_sum_id, (dst, expected_sum_id)
            continue
        # Any remaining edge must start at a sum operator
        assert src_op.type == OpType.Sum, (src_op.type, OpType.Sum)
def _test_shuffle_channels():
    """Checks the channels generated for a shuffled stream."""
    env = Environment()
    # One source shuffling into a map with four instances
    env.source(None).shuffle().map(None).set_parallelism(4)
    _test_channels(env, [(0, 0), (0, 1), (0, 2), (0, 3)])
def _test_forward_channels():
    """Checks the channels generated for the default partitioning."""
    env = Environment()
    # Four source instances feeding two map instances, no explicit strategy
    env.source(None).set_parallelism(4).map(None).set_parallelism(2)
    _test_channels(env, [(0, 0), (1, 1), (2, 0), (3, 1)])
def _test_broadcast_channels():
    """Checks the channels generated for a broadcast stream."""
    env = Environment()
    # Four source instances broadcasting to two map instances
    env.source(None).set_parallelism(4).broadcast().map(
        None).set_parallelism(2)
    # Every source instance connects to every map instance
    all_pairs = [(0, 0), (0, 1), (1, 0), (1, 1),
                 (2, 0), (2, 1), (3, 0), (3, 1)]
    _test_channels(env, all_pairs)
def _test_round_robin_channels():
    """Checks the channels generated for a round-robin stream."""
    env = Environment()
    # One source distributing round-robin over two map instances
    env.source(None).round_robin().map(None).set_parallelism(2)
    _test_channels(env, [(0, 0), (0, 1)])
def _test_channels(environment, expected_channels):
    """Checks that the generated channels match ``expected_channels``.

    ``expected_channels`` is an iterable of (source instance index,
    destination instance index) pairs; every generated channel must end
    at the map operator.
    """
    environment._collect_garbage()
    # Find the id of the map operator (the last one wins if several exist)
    map_id = None
    for op_id, op in environment.operators.items():
        if op.type == OpType.Map:
            map_id = op_id
    # Build the execution graph and generate the channels of each operator
    environment.execution_graph = ExecutionGraph(environment)
    environment.execution_graph.build_channels()
    generated = [
        environment.execution_graph._generate_channels(op)
        for op in environment.operators.values()
    ]
    # Flatten the channels into (source index, destination index) pairs
    actual = []
    for per_destination in generated:
        for channel_list in per_destination.values():
            for channel in channel_list:
                pair = (channel.src_instance_index,
                        channel.dst_instance_index)
                assert channel.dst_operator_id == map_id, (
                    channel.dst_operator_id, map_id)
                actual.append(pair)
    # Compare as sets: order and duplicates do not matter
    expected_set = set(expected_channels)
    actual_set = set(actual)
    assert expected_set == actual_set, (expected_set, actual_set)
def test_channel_generation():
    """Runs every data channel generation check in turn."""
    checks = (
        _test_shuffle_channels,
        _test_broadcast_channels,
        _test_round_robin_channels,
        _test_forward_channels,
    )
    for check in checks:
        check()
# TODO (john): Add simple wordcount test
def test_wordcount():
    """Placeholder for a simple streaming wordcount test (does nothing)."""
    return None
# When run as a script, only the channel generation checks are executed.
if __name__ == "__main__":
    test_channel_generation()
+8
View File
@@ -0,0 +1,8 @@
from ray.streaming import operator
from ray.streaming import function
def test_create_operator():
    """A SimpleMapFunction must be wrapped in a MapOperator."""
    identity_map = function.SimpleMapFunction(lambda value: value)
    created = operator.create_operator(identity_map)
    assert type(created) is operator.MapOperator
+15 -10
View File
@@ -1,18 +1,23 @@
import ray
from ray.streaming.config import Config
from ray.streaming.streaming import Environment, Conf
from ray.streaming import StreamingContext
def test_word_count():
    # NOTE(review): this span looks like a diff rendering that interleaves
    # the removed lines (Environment / env.*) with the added lines
    # (StreamingContext / ctx.*); it is not valid as one function body.
    # Confirm against the real file which side is live before editing.
    ray.init()
    env = Environment(config=Conf(channel_type=Config.NATIVE_CHANNEL))
    env.read_text_file(__file__) \
    ray.init(load_code_from_local=True, include_java=True)
    ctx = StreamingContext.Builder() \
        .build()
    ctx.read_text_file(__file__) \
        .set_parallelism(1) \
        .filter(lambda x: "word" in x) \
        .inspect(lambda x: print("result", x))
    env_handle = env.execute()
    ray.get(env_handle) # Stay alive until execution finishes
    env.wait_finish()
        .flat_map(lambda x: x.split()) \
        .map(lambda x: (x, 1)) \
        .key_by(lambda x: x[0]) \
        .reduce(lambda old_value, new_value:
                (old_value[0], old_value[1] + new_value[1])) \
        .filter(lambda x: "ray" not in x) \
        .sink(lambda x: print("result", x))
    ctx.submit("word_count")
    import time
    time.sleep(3)
    ray.shutdown()