Liskov'ify feed & merge.

This commit is contained in:
Stephen Diehl
2012-05-16 17:30:10 -04:00
parent b1f8bbd189
commit 7f4c94d4e1
3 changed files with 36 additions and 170 deletions
+9 -66
View File
@@ -1,7 +1,9 @@
"""
Component
Abstract base class for Feed and Merge.
Component
|
Aggregate
Aggregate
|
/ \
Feed Merge
@@ -25,27 +27,11 @@ class Aggregate(Component):
- pull_socket
- feed_socket
Both use ``data_buffer`` for buffering.
Feed and Merge define these differently.
"""
def init(self):
    """
    Reset the component's bookkeeping to its empty starting state.

    Sets the aggregate sent/received totals, the draining flag, the
    finished-datasource counter, the event buffer and the per-source
    sent/received counters.
    """
    # source_id -> integer count
    self.recv_counters = Counter()
    self.sent_counters = Counter()

    # Depending on the size of this, might want to use a data
    # structure with better asymptotics.
    self.data_buffer = {}

    self.ds_finished_counter = 0
    self.draining = False
    self.received_count = 0
    self.sent_count = 0
@property
def get_id(self):
    """
    Identifier of the component.

    Not implemented at this level: accessing it raises
    NotImplementedError.
    """
    raise NotImplementedError
@property
def get_type(self):
    """Component type tag: always ``COMPONENT_TYPE.CONDUIT``."""
    component_type = COMPONENT_TYPE.CONDUIT
    return component_type
@@ -54,10 +40,6 @@ class Aggregate(Component):
# Core Methods
# -------------
def open(self):
    """
    Bind the component's two sockets.

    ``pull_socket`` comes from ``bind_data()`` and ``feed_socket`` from
    ``bind_feed()``, bound in that order.
    """
    data_socket = self.bind_data()
    self.pull_socket = data_socket

    outgoing_socket = self.bind_feed()
    self.feed_socket = outgoing_socket
def do_work(self):
# wait for synchronization reply from the host
socks = dict(self.poll.poll(self.heartbeat_timeout))
@@ -131,55 +113,16 @@ class Aggregate(Component):
if not (self.is_full() or self.draining):
return
# TODO: implement this in __iter__
event = self.next()
if event is not None:
if(event != None):
self.feed_socket.send(self.frame(event), self.zmq.NOBLOCK)
self.sent_counters[event.source_id] += 1
self.sent_count += 1
def append(self, event):
    """
    Buffer ``event`` under its ``source_id`` and update the receive
    tallies (per-source counter and overall total).
    """
    source_id = event.source_id
    self.data_buffer[source_id].append(event)
    self.recv_counters[source_id] += 1
    self.received_count += 1
def next(self):
    """
    Get the next message in chronological order.

    Inspects the event at the head of every source queue in
    ``data_buffer`` and pops the one with the smallest ``dt``.  When
    several heads share the same ``dt`` the queue visited last wins.
    A head whose ``dt`` is None is a filler event: it is discarded and
    its source is skipped for the rest of this call.

    Returns None when the buffer is neither full nor draining, or when
    no queue has a dated event at its head.
    """
    if not (self.is_full() or self.draining):
        return None

    earliest_queue = None
    earliest_event = None

    # iterate over the queues of events from all sources
    # (1 queue per datasource)
    for queue in self.data_buffer.values():
        if len(queue) == 0:
            continue

        head = queue[0]
        if head.dt is None:
            # this is a filler event, discard
            # NOTE(review): the event queued behind the filler is not
            # examined until the next call -- confirm this is intended.
            queue.pop(0)
            continue

        if earliest_event is None or head.dt <= earliest_event.dt:
            earliest_event = head
            earliest_queue = queue

    if earliest_event is None:
        return None
    return earliest_queue.pop(0)
def is_full(self):
"""
Indicates whether the buffer has messages in buffer for
all un-DONE, blocking sources.
Indicates whether the buffer has messages in buffer for all
un-DONE, blocking sources.
"""
for source_id, events in self.data_buffer.iteritems():
if len(events) == 0:
+8 -94
View File
@@ -2,6 +2,7 @@ import logging
from collections import Counter
from zipline.core.component import Component
from zipline.components.aggregator import Aggregate
import zipline.protocol as zp
from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE, \
@@ -9,7 +10,7 @@ from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE, \
LOGGER = logging.getLogger('ZiplineLogger')
class Feed(Component):
class Feed(Aggregate):
"""
Connects to N PULL sockets, publishing all messages received to a
PUB socket. Published messages are guaranteed to be in chronological
@@ -35,71 +36,17 @@ class Feed(Component):
def get_id(self):
return "FEED"
@property
def get_type(self):
    """Component type tag: always ``COMPONENT_TYPE.CONDUIT``."""
    component_type = COMPONENT_TYPE.CONDUIT
    return component_type
# -------------
# Core Methods
# -------------
# -------
# Sockets
# -------
def open(self):
    """
    Bind the component's two sockets.

    ``pull_socket`` comes from ``bind_data()`` and ``feed_socket`` from
    ``bind_feed()``, bound in that order.
    """
    data_socket = self.bind_data()
    self.pull_socket = data_socket

    outgoing_socket = self.bind_feed()
    self.feed_socket = outgoing_socket
def do_work(self):
    """
    Run one polling cycle: service the control socket, then the data
    socket.

    Control messages: HEARTBEAT is answered with an OK frame carrying
    the same payload, SHUTDOWN signals done and shuts down, KILL kills
    the component.

    Data messages: a DONE marker bumps ``ds_finished_counter`` and, once
    it equals the number of entries in ``data_buffer``, drains the
    buffer and signals done.  Any other message is unframed, appended
    and followed by ``send_next()``.  If ``INVALID_DATASOURCE_FRAME`` is
    raised along the way, the result of ``signal_exception`` is
    returned.
    """
    # wait for synchronization reply from the host
    socks = dict(self.poll.poll(self.heartbeat_timeout))

    # TODO: Abstract this out, maybe on base component
    if socks.get(self.control_in) == self.zmq.POLLIN:
        event, payload = CONTROL_UNFRAME(self.control_in.recv())

        if event == CONTROL_PROTOCOL.HEARTBEAT:
            # -- Heartbeat --: heart outgoing
            self.control_out.send(
                CONTROL_FRAME(CONTROL_PROTOCOL.OK, payload)
            )
        elif event == CONTROL_PROTOCOL.SHUTDOWN:
            # -- Soft Kill --
            self.signal_done()
            self.shutdown()
        elif event == CONTROL_PROTOCOL.KILL:
            # -- Hard Kill --
            self.kill()

    # Nothing to read on the data socket this cycle.
    if socks.get(self.pull_socket) != self.zmq.POLLIN:
        return None

    message = self.pull_socket.recv()

    if message == str(CONTROL_PROTOCOL.DONE):
        self.ds_finished_counter += 1
        if self.ds_finished_counter == len(self.data_buffer):
            # drain any remaining messages in the buffer
            LOGGER.debug("draining feed")
            self.drain()
            self.signal_done()
        return None

    try:
        event = self.unframe(message)
        self.append(event)
        self.send_next()
    except zp.INVALID_DATASOURCE_FRAME as exc:
        # deserialization error or invalid message
        return self.signal_exception(exc)
# -------------
# Core Methods
# -------------
def unframe(self, msg):
    """Decode ``msg`` with ``zp.DATASOURCE_UNFRAME`` and return the result."""
    decoded = zp.DATASOURCE_UNFRAME(msg)
    return decoded
@@ -169,36 +116,3 @@ class Feed(Component):
if earliest_event != None:
return earliest_source.pop(0)
def is_full(self):
    """
    Indicates whether the buffer has messages in buffer for
    all un-DONE, blocking sources.

    Returns True when every event list in ``data_buffer`` holds at
    least one event, False as soon as one is empty.  A buffer with no
    sources at all counts as full.
    """
    # Only the queues matter here; the source ids are not needed.
    return all(len(events) > 0 for events in self.data_buffer.values())
def pending_messages(self):
    """
    Count the events currently held in the buffer, summed over every
    source.
    """
    return sum(len(queue) for queue in self.data_buffer.values())
def add_source(self, source_id):
    """
    Register ``source_id`` in the buffer with a fresh, empty event list.

    Anything already stored under that id is replaced.
    """
    fresh_queue = []
    self.data_buffer[source_id] = fresh_queue
def __len__(self):
    """
    Number of entries in the buffer, i.e. the size of the internal map
    that keeps one event list per source id.
    """
    source_queues = self.data_buffer
    return len(source_queues)
+19 -10
View File
@@ -2,10 +2,11 @@ from feed import Feed
import zipline.protocol as zp
from zipline.protocol import COMPONENT_TYPE
from zipline.components.aggregator import Aggregate
from collections import Counter
class Merge(Feed):
class Merge(Aggregate):
"""
Merges multiple streams of events into single messages.
"""
@@ -27,14 +28,28 @@ class Merge(Feed):
def get_id(self):
return "MERGE"
@property
def get_type(self):
    """Component type tag: always ``COMPONENT_TYPE.CONDUIT``."""
    component_type = COMPONENT_TYPE.CONDUIT
    return component_type
# -------
# Sockets
# -------
def open(self):
    """
    Bind the component's two sockets.

    ``pull_socket`` comes from ``bind_merge()`` and ``feed_socket`` from
    ``bind_result()``, bound in that order.
    """
    merge_socket = self.bind_merge()
    self.pull_socket = merge_socket

    result_socket = self.bind_result()
    self.feed_socket = result_socket
# -------
# Framing
# -------
def unframe(self, msg):
    """Decode ``msg`` with ``zp.TRANSFORM_UNFRAME`` and return the result."""
    decoded = zp.TRANSFORM_UNFRAME(msg)
    return decoded
def frame(self, event):
    """Encode ``event`` with ``zp.MERGE_FRAME`` and return the result."""
    encoded = zp.MERGE_FRAME(event)
    return encoded
# ---------
# Data Flow
# ---------
def next(self):
"""Get the next merged message from the feed buffer."""
if not (self.is_full() or self.draining):
@@ -53,12 +68,6 @@ class Merge(Feed):
result.merge(cur)
return result
def unframe(self, msg):
    """Decode ``msg`` with ``zp.TRANSFORM_UNFRAME`` and return the result."""
    decoded = zp.TRANSFORM_UNFRAME(msg)
    return decoded
def frame(self, event):
    """Encode ``event`` with ``zp.MERGE_FRAME`` and return the result."""
    encoded = zp.MERGE_FRAME(event)
    return encoded
def append(self, event):
"""
:param event: a ndict with one entry. key is the name of the