mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 12:08:34 +08:00
Liskov'ify feed & merge.
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
"""
|
||||
Component
|
||||
Abstract base class for Feed and Merge.
|
||||
|
||||
Component
|
||||
|
|
||||
Aggregate
|
||||
Aggregate
|
||||
|
|
||||
/ \
|
||||
Feed Merge
|
||||
@@ -25,27 +27,11 @@ class Aggregate(Component):
|
||||
- pull_socket
|
||||
- feed_socket
|
||||
|
||||
Both use ``data_buffer`` for buffering.
|
||||
|
||||
Feed and Merge define these differently.
|
||||
"""
|
||||
|
||||
def init(self):
|
||||
self.sent_count = 0
|
||||
self.received_count = 0
|
||||
self.draining = False
|
||||
self.ds_finished_counter = 0
|
||||
|
||||
# Depending on the size of this, might want to use a data
|
||||
# structure with better asymptotics.
|
||||
self.data_buffer = {}
|
||||
|
||||
# source_id -> integer count
|
||||
self.sent_counters = Counter()
|
||||
self.recv_counters = Counter()
|
||||
|
||||
@property
|
||||
def get_id(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def get_type(self):
|
||||
return COMPONENT_TYPE.CONDUIT
|
||||
@@ -54,10 +40,6 @@ class Aggregate(Component):
|
||||
# Core Methods
|
||||
# -------------
|
||||
|
||||
def open(self):
|
||||
self.pull_socket = self.bind_data()
|
||||
self.feed_socket = self.bind_feed()
|
||||
|
||||
def do_work(self):
|
||||
# wait for synchronization reply from the host
|
||||
socks = dict(self.poll.poll(self.heartbeat_timeout))
|
||||
@@ -131,55 +113,16 @@ class Aggregate(Component):
|
||||
if not (self.is_full() or self.draining):
|
||||
return
|
||||
|
||||
# TODO: implement this in __iter__
|
||||
event = self.next()
|
||||
if event is not None:
|
||||
if(event != None):
|
||||
self.feed_socket.send(self.frame(event), self.zmq.NOBLOCK)
|
||||
self.sent_counters[event.source_id] += 1
|
||||
self.sent_count += 1
|
||||
|
||||
def append(self, event):
|
||||
"""
|
||||
Add an event to the buffer for the source specified by
|
||||
source_id.
|
||||
"""
|
||||
self.data_buffer[event.source_id].append(event)
|
||||
self.recv_counters[event.source_id] += 1
|
||||
self.received_count += 1
|
||||
|
||||
def next(self):
|
||||
"""
|
||||
Get the next message in chronological order.
|
||||
"""
|
||||
if not(self.is_full() or self.draining):
|
||||
return
|
||||
|
||||
cur_source = None
|
||||
earliest_source = None
|
||||
earliest_event = None
|
||||
#iterate over the queues of events from all sources
|
||||
#(1 queue per datasource)
|
||||
for events in self.data_buffer.itervalues():
|
||||
if len(events) == 0:
|
||||
continue
|
||||
cur_source = events
|
||||
first_in_list = events[0]
|
||||
if first_in_list.dt == None:
|
||||
#this is a filler event, discard
|
||||
events.pop(0)
|
||||
continue
|
||||
|
||||
if (earliest_event == None) or (first_in_list.dt <= earliest_event.dt):
|
||||
earliest_event = first_in_list
|
||||
earliest_source = cur_source
|
||||
|
||||
if earliest_event != None:
|
||||
return earliest_source.pop(0)
|
||||
|
||||
def is_full(self):
|
||||
"""
|
||||
Indicates whether the buffer has messages in buffer for
|
||||
all un-DONE, blocking sources.
|
||||
Indicates whether the buffer has messages in buffer for all
|
||||
un-DONE, blocking sources.
|
||||
"""
|
||||
for source_id, events in self.data_buffer.iteritems():
|
||||
if len(events) == 0:
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
from collections import Counter
|
||||
|
||||
from zipline.core.component import Component
|
||||
from zipline.components.aggregator import Aggregate
|
||||
import zipline.protocol as zp
|
||||
|
||||
from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE, \
|
||||
@@ -9,7 +10,7 @@ from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE, \
|
||||
|
||||
LOGGER = logging.getLogger('ZiplineLogger')
|
||||
|
||||
class Feed(Component):
|
||||
class Feed(Aggregate):
|
||||
"""
|
||||
Connects to N PULL sockets, publishing all messages received to a
|
||||
PUB socket. Published messages are guaranteed to be in chronological
|
||||
@@ -35,71 +36,17 @@ class Feed(Component):
|
||||
def get_id(self):
|
||||
return "FEED"
|
||||
|
||||
@property
|
||||
def get_type(self):
|
||||
return COMPONENT_TYPE.CONDUIT
|
||||
|
||||
# -------------
|
||||
# Core Methods
|
||||
# -------------
|
||||
# -------
|
||||
# Sockets
|
||||
# -------
|
||||
|
||||
def open(self):
|
||||
self.pull_socket = self.bind_data()
|
||||
self.feed_socket = self.bind_feed()
|
||||
|
||||
def do_work(self):
|
||||
# wait for synchronization reply from the host
|
||||
socks = dict(self.poll.poll(self.heartbeat_timeout))
|
||||
|
||||
# TODO: Abstract this out, maybe on base component
|
||||
if self.control_in in socks and socks[self.control_in] == self.zmq.POLLIN:
|
||||
msg = self.control_in.recv()
|
||||
event, payload = CONTROL_UNFRAME(msg)
|
||||
|
||||
# -- Heartbeat --
|
||||
if event == CONTROL_PROTOCOL.HEARTBEAT:
|
||||
# Heart outgoing
|
||||
heartbeat_frame = CONTROL_FRAME(
|
||||
CONTROL_PROTOCOL.OK,
|
||||
payload
|
||||
)
|
||||
self.control_out.send(heartbeat_frame)
|
||||
|
||||
# -- Soft Kill --
|
||||
elif event == CONTROL_PROTOCOL.SHUTDOWN:
|
||||
self.signal_done()
|
||||
self.shutdown()
|
||||
|
||||
# -- Hard Kill --
|
||||
elif event == CONTROL_PROTOCOL.KILL:
|
||||
self.kill()
|
||||
|
||||
|
||||
if self.pull_socket in socks and socks[self.pull_socket] == self.zmq.POLLIN:
|
||||
message = self.pull_socket.recv()
|
||||
|
||||
if message == str(CONTROL_PROTOCOL.DONE):
|
||||
self.ds_finished_counter += 1
|
||||
|
||||
if len(self.data_buffer) == self.ds_finished_counter:
|
||||
#drain any remaining messages in the buffer
|
||||
LOGGER.debug("draining feed")
|
||||
self.drain()
|
||||
self.signal_done()
|
||||
else:
|
||||
try:
|
||||
event = self.unframe(message)
|
||||
# deserialization error
|
||||
except zp.INVALID_DATASOURCE_FRAME as exc:
|
||||
return self.signal_exception(exc)
|
||||
|
||||
try:
|
||||
self.append(event)
|
||||
self.send_next()
|
||||
|
||||
# Invalid message
|
||||
except zp.INVALID_DATASOURCE_FRAME as exc:
|
||||
return self.signal_exception(exc)
|
||||
# -------------
|
||||
# Core Methods
|
||||
# -------------
|
||||
|
||||
def unframe(self, msg):
|
||||
return zp.DATASOURCE_UNFRAME(msg)
|
||||
@@ -169,36 +116,3 @@ class Feed(Component):
|
||||
|
||||
if earliest_event != None:
|
||||
return earliest_source.pop(0)
|
||||
|
||||
def is_full(self):
|
||||
"""
|
||||
Indicates whether the buffer has messages in buffer for
|
||||
all un-DONE, blocking sources.
|
||||
"""
|
||||
for source_id, events in self.data_buffer.iteritems():
|
||||
if len(events) == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
def pending_messages(self):
|
||||
"""
|
||||
Returns the count of all events from all sources in the
|
||||
buffer.
|
||||
"""
|
||||
total = 0
|
||||
for events in self.data_buffer.values():
|
||||
total += len(events)
|
||||
return total
|
||||
|
||||
def add_source(self, source_id):
|
||||
"""
|
||||
Add a data source to the buffer.
|
||||
"""
|
||||
self.data_buffer[source_id] = []
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Buffer's length is same as internal map holding separate sorted
|
||||
arrays of events keyed by source id.
|
||||
"""
|
||||
return len(self.data_buffer)
|
||||
|
||||
+19
-10
@@ -2,10 +2,11 @@ from feed import Feed
|
||||
|
||||
import zipline.protocol as zp
|
||||
from zipline.protocol import COMPONENT_TYPE
|
||||
from zipline.components.aggregator import Aggregate
|
||||
|
||||
from collections import Counter
|
||||
|
||||
class Merge(Feed):
|
||||
class Merge(Aggregate):
|
||||
"""
|
||||
Merges multiple streams of events into single messages.
|
||||
"""
|
||||
@@ -27,14 +28,28 @@ class Merge(Feed):
|
||||
def get_id(self):
|
||||
return "MERGE"
|
||||
|
||||
@property
|
||||
def get_type(self):
|
||||
return COMPONENT_TYPE.CONDUIT
|
||||
# -------
|
||||
# Sockets
|
||||
# -------
|
||||
|
||||
def open(self):
|
||||
self.pull_socket = self.bind_merge()
|
||||
self.feed_socket = self.bind_result()
|
||||
|
||||
# -------
|
||||
# Framing
|
||||
# -------
|
||||
|
||||
def unframe(self, msg):
|
||||
return zp.TRANSFORM_UNFRAME(msg)
|
||||
|
||||
def frame(self, event):
|
||||
return zp.MERGE_FRAME(event)
|
||||
|
||||
# ---------
|
||||
# Data Flow
|
||||
# ---------
|
||||
|
||||
def next(self):
|
||||
"""Get the next merged message from the feed buffer."""
|
||||
if not (self.is_full() or self.draining):
|
||||
@@ -53,12 +68,6 @@ class Merge(Feed):
|
||||
result.merge(cur)
|
||||
return result
|
||||
|
||||
def unframe(self, msg):
|
||||
return zp.TRANSFORM_UNFRAME(msg)
|
||||
|
||||
def frame(self, event):
|
||||
return zp.MERGE_FRAME(event)
|
||||
|
||||
def append(self, event):
|
||||
"""
|
||||
:param event: a ndict with one entry. key is the name of the
|
||||
|
||||
Reference in New Issue
Block a user