diff --git a/zipline/components/aggregator.py b/zipline/components/aggregator.py index d0b18208..5aadb7e8 100644 --- a/zipline/components/aggregator.py +++ b/zipline/components/aggregator.py @@ -1,7 +1,9 @@ """ - Component +Abstract base class for Feed and Merge. + + Component | - Aggregate + Aggregate | / \ Feed Merge @@ -25,27 +27,11 @@ class Aggregate(Component): - pull_socket - feed_socket + Both use ``data_buffer`` for buffering. + Feed and Merge define these differently. """ - def init(self): - self.sent_count = 0 - self.received_count = 0 - self.draining = False - self.ds_finished_counter = 0 - - # Depending on the size of this, might want to use a data - # structure with better asymptotics. - self.data_buffer = {} - - # source_id -> integer count - self.sent_counters = Counter() - self.recv_counters = Counter() - - @property - def get_id(self): - raise NotImplementedError - @property def get_type(self): return COMPONENT_TYPE.CONDUIT @@ -54,10 +40,6 @@ class Aggregate(Component): # Core Methods # ------------- - def open(self): - self.pull_socket = self.bind_data() - self.feed_socket = self.bind_feed() - def do_work(self): # wait for synchronization reply from the host socks = dict(self.poll.poll(self.heartbeat_timeout)) @@ -131,55 +113,16 @@ class Aggregate(Component): if not (self.is_full() or self.draining): return - # TODO: implement this in __iter__ event = self.next() - if event is not None: + if(event != None): self.feed_socket.send(self.frame(event), self.zmq.NOBLOCK) self.sent_counters[event.source_id] += 1 self.sent_count += 1 - def append(self, event): - """ - Add an event to the buffer for the source specified by - source_id. - """ - self.data_buffer[event.source_id].append(event) - self.recv_counters[event.source_id] += 1 - self.received_count += 1 - - def next(self): - """ - Get the next message in chronological order. - """ - if not(self.is_full() or self.draining): - return - - cur_source = None - earliest_source = None - earliest_event = None - #iterate over the queues of events from all sources - #(1 queue per datasource) - for events in self.data_buffer.itervalues(): - if len(events) == 0: - continue - cur_source = events - first_in_list = events[0] - if first_in_list.dt == None: - #this is a filler event, discard - events.pop(0) - continue - - if (earliest_event == None) or (first_in_list.dt <= earliest_event.dt): - earliest_event = first_in_list - earliest_source = cur_source - - if earliest_event != None: - return earliest_source.pop(0) - def is_full(self): """ - Indicates whether the buffer has messages in buffer for - all un-DONE, blocking sources. + Indicates whether the buffer has messages in buffer for all + un-DONE, blocking sources. """ for source_id, events in self.data_buffer.iteritems(): if len(events) == 0: diff --git a/zipline/components/feed.py b/zipline/components/feed.py index 75847e0d..72ad3f9a 100644 --- a/zipline/components/feed.py +++ b/zipline/components/feed.py @@ -2,6 +2,7 @@ import logging from collections import Counter from zipline.core.component import Component +from zipline.components.aggregator import Aggregate import zipline.protocol as zp from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE, \ @@ -9,7 +10,7 @@ from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE, \ LOGGER = logging.getLogger('ZiplineLogger') -class Feed(Component): +class Feed(Aggregate): """ Connects to N PULL sockets, publishing all messages received to a PUB socket. Published messages are guaranteed to be in chronological @@ -35,71 +36,17 @@ class Feed(Component): def get_id(self): return "FEED" - @property - def get_type(self): - return COMPONENT_TYPE.CONDUIT - - # ------------- - # Core Methods - # ------------- + # ------- + # Sockets + # ------- def open(self): self.pull_socket = self.bind_data() self.feed_socket = self.bind_feed() - def do_work(self): - # wait for synchronization reply from the host - socks = dict(self.poll.poll(self.heartbeat_timeout)) - - # TODO: Abstract this out, maybe on base component - if self.control_in in socks and socks[self.control_in] == self.zmq.POLLIN: - msg = self.control_in.recv() - event, payload = CONTROL_UNFRAME(msg) - - # -- Heartbeat -- - if event == CONTROL_PROTOCOL.HEARTBEAT: - # Heart outgoing - heartbeat_frame = CONTROL_FRAME( - CONTROL_PROTOCOL.OK, - payload - ) - self.control_out.send(heartbeat_frame) - - # -- Soft Kill -- - elif event == CONTROL_PROTOCOL.SHUTDOWN: - self.signal_done() - self.shutdown() - - # -- Hard Kill -- - elif event == CONTROL_PROTOCOL.KILL: - self.kill() - - - if self.pull_socket in socks and socks[self.pull_socket] == self.zmq.POLLIN: - message = self.pull_socket.recv() - - if message == str(CONTROL_PROTOCOL.DONE): - self.ds_finished_counter += 1 - - if len(self.data_buffer) == self.ds_finished_counter: - #drain any remaining messages in the buffer - LOGGER.debug("draining feed") - self.drain() - self.signal_done() - else: - try: - event = self.unframe(message) - # deserialization error - except zp.INVALID_DATASOURCE_FRAME as exc: - return self.signal_exception(exc) - - try: - self.append(event) - self.send_next() - - # Invalid message - except zp.INVALID_DATASOURCE_FRAME as exc: - return self.signal_exception(exc) + # ------------- + # Core Methods + # ------------- def unframe(self, msg): return zp.DATASOURCE_UNFRAME(msg) @@ -169,36 +116,3 @@ class Feed(Component): if earliest_event != None: return earliest_source.pop(0) - - def is_full(self): - """ - Indicates whether the buffer has messages in buffer for - all un-DONE, blocking sources. - """ - for source_id, events in self.data_buffer.iteritems(): - if len(events) == 0: - return False - return True - - def pending_messages(self): - """ - Returns the count of all events from all sources in the - buffer. - """ - total = 0 - for events in self.data_buffer.values(): - total += len(events) - return total - - def add_source(self, source_id): - """ - Add a data source to the buffer. - """ - self.data_buffer[source_id] = [] - - def __len__(self): - """ - Buffer's length is same as internal map holding separate sorted - arrays of events keyed by source id. - """ - return len(self.data_buffer) diff --git a/zipline/components/merge.py b/zipline/components/merge.py index 0ef29bc2..2ad64b4a 100644 --- a/zipline/components/merge.py +++ b/zipline/components/merge.py @@ -2,10 +2,11 @@ from feed import Feed import zipline.protocol as zp from zipline.protocol import COMPONENT_TYPE +from zipline.components.aggregator import Aggregate from collections import Counter -class Merge(Feed): +class Merge(Aggregate): """ Merges multiple streams of events into single messages. """ @@ -27,14 +28,28 @@ class Merge(Feed): def get_id(self): return "MERGE" - @property - def get_type(self): - return COMPONENT_TYPE.CONDUIT + # ------- + # Sockets + # ------- def open(self): self.pull_socket = self.bind_merge() self.feed_socket = self.bind_result() + # ------- + # Framing + # ------- + + def unframe(self, msg): + return zp.TRANSFORM_UNFRAME(msg) + + def frame(self, event): + return zp.MERGE_FRAME(event) + + # --------- + # Data Flow + # --------- + def next(self): """Get the next merged message from the feed buffer.""" if not (self.is_full() or self.draining): @@ -53,12 +68,6 @@ class Merge(Feed): result.merge(cur) return result - def unframe(self, msg): - return zp.TRANSFORM_UNFRAME(msg) - - def frame(self, event): - return zp.MERGE_FRAME(event) - def append(self, event): """ :param event: a ndict with one entry. key is the name of the