From e32adba72dc4ab9fab3950c4c39c3ec382f2be38 Mon Sep 17 00:00:00 2001 From: scottsanderson Date: Sat, 28 Jul 2012 00:13:48 -0400 Subject: [PATCH 1/7] gen-style SpecificEquityTrades done --- zipline/gens/tradegens.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/zipline/gens/tradegens.py b/zipline/gens/tradegens.py index 24ece45b..2b2598e8 100644 --- a/zipline/gens/tradegens.py +++ b/zipline/gens/tradegens.py @@ -1,5 +1,5 @@ import random -from itertools import chain, repeat, cycle, ifilter +from itertools import chain, repeat, cycle, ifilter, izip from datetime import datetime, timedelta from zipline.utils.factory import create_trade, create_trade @@ -20,7 +20,7 @@ def mock_volumes(n, rand = False): for readability.""" return mock_prices(n, rand) -def SpecificEquityTrades(n = 500, sids = [1, 2], event_list = None, filter = None): +def SpecificEquityTrades(count = 500, sids = [1, 2], event_list = None, filter = None): """Returns the first n events of event_list if specified. Otherwise generates a sensible stream of events.""" @@ -28,9 +28,9 @@ def SpecificEquityTrades(n = 500, sids = [1, 2], event_list = None, filter = Non unfiltered = (event for event in event_list) else: - dates = date_gen(n = n) - prices = mock_prices(n) - volumes = mock_volumes(n) + dates = date_gen(n = count) + prices = mock_prices(count) + volumes = mock_volumes(count) sids = cycle(iter(sids)) arg_gen = izip(sids, prices, volumes, dates) @@ -42,3 +42,8 @@ def SpecificEquityTrades(n = 500, sids = [1, 2], event_list = None, filter = Non filtered = unfiltered return filtered + +if __name__ == "__main__": + + import nose.tools; nose.tools.set_trace() + trades = SpecificEquityTrades() From 1f40684566211a93b3c362acf1b8db87f3575880 Mon Sep 17 00:00:00 2001 From: scottsanderson Date: Sat, 28 Jul 2012 00:17:28 -0400 Subject: [PATCH 2/7] patch for filter logic --- zipline/gens/tradegens.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zipline/gens/tradegens.py b/zipline/gens/tradegens.py index 2b2598e8..6404c403 100644 --- a/zipline/gens/tradegens.py +++ b/zipline/gens/tradegens.py @@ -37,7 +37,7 @@ def SpecificEquityTrades(count = 500, sids = [1, 2], event_list = None, filter = unfiltered = (create_trade(*args) for args in arg_gen) if filter: - filtered = ifilter(lambda event: event.sid in filter) + filtered = ifilter(lambda event: event.sid in filter, unfiltered) else: filtered = unfiltered @@ -46,4 +46,4 @@ def SpecificEquityTrades(count = 500, sids = [1, 2], event_list = None, filter = if __name__ == "__main__": import nose.tools; nose.tools.set_trace() - trades = SpecificEquityTrades() + trades = SpecificEquityTrades(filter = [1]) From 71cc67e1239e05a7860a62117a29efb9224e7c87 Mon Sep 17 00:00:00 2001 From: scottsanderson Date: Sat, 28 Jul 2012 01:03:09 -0400 Subject: [PATCH 3/7] generator random equity trades --- zipline/finance/sources.py | 2 +- zipline/gens/composites.py | 2 +- zipline/gens/tradegens.py | 43 ++++++++++++++++++++++++++++++++------ zipline/gens/utils.py | 2 +- 4 files changed, 40 insertions(+), 9 deletions(-) diff --git a/zipline/finance/sources.py b/zipline/finance/sources.py index 0aba2186..bfa9e86f 100644 --- a/zipline/finance/sources.py +++ b/zipline/finance/sources.py @@ -90,7 +90,7 @@ class RandomEquityTrades(TradeDataSource): class SpecificEquityTrades(TradeDataSource): """ - Generates a random stream of trades for testing. + Generates a non-random stream of trades for testing. """ def init(self, event_list): diff --git a/zipline/gens/composites.py b/zipline/gens/composites.py index 1b710212..ac24c7ce 100644 --- a/zipline/gens/composites.py +++ b/zipline/gens/composites.py @@ -9,7 +9,7 @@ from zipline.gens.transform def PreTransformLayer(sources): """A generator that takes a list of sources and runs their output through a FeedGen.""" - not_finished = len_ + not_finished = len #NOT DONE while not_finished: diff --git a/zipline/gens/tradegens.py b/zipline/gens/tradegens.py index 6404c403..619721e4 100644 --- a/zipline/gens/tradegens.py +++ b/zipline/gens/tradegens.py @@ -16,11 +16,24 @@ def mock_prices(n, rand = False): return (float(i % 11) for i in xrange(1,n+1)) def mock_volumes(n, rand = False): - """Does the same as mock_prices. Different function name - for readability.""" - return mock_prices(n, rand) + """Utility to generate a set of volumes. By default cycles + through values from 100 to 1000, incrementing by 50. Optional + flag to give random values between 100 and 1000. """ + + if rand: + return (random.randrange(100, 1000) for i in xrange(n)) + else: + return ((i * 50)%900 + 100 for i in xrange(n)) + +def fuzzy_dates(count = 500): + """Add +-10 seconds to each event from a date_gen. Note that + this still guarantees sorting, since the default is minute separation + of events.""" + for date in date_gen(n = count): + yield date + timedelta(seconds = random.randint(-10, 10)) def SpecificEquityTrades(count = 500, sids = [1, 2], event_list = None, filter = None): + """Returns the first n events of event_list if specified. Otherwise generates a sensible stream of events.""" @@ -43,7 +56,25 @@ def SpecificEquityTrades(count = 500, sids = [1, 2], event_list = None, filter = return filtered -if __name__ == "__main__": +def RandomEquityTrades(count = 500, sids = [1,2], filter = None): + dates = fuzzy_dates(500) + prices = mock_prices(500, rand = True) + volumes = mock_volumes(500, rand = True) + sids = cycle(iter(sids)) + + arg_gen = izip(sids, prices, volumes, dates) - import nose.tools; nose.tools.set_trace() - trades = SpecificEquityTrades(filter = [1]) + unfiltered = (create_trade(*args) for args in arg_gen) + + if filter: + filtered = ifilter(lambda event: event.sid in filter, unfiltered) + else: + filtered = unfiltered + return filtered + +if __name__ == "__main__": + rand = RandomEquityTrades() + pass +# x = mock_volumes(500) +# import nose.tools; nose.tools.set_trace() +# trades = SpecificEquityTrades(filter = [1]) diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index efe81a14..1060cd24 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -17,7 +17,7 @@ def mock_raw_event(sid, dt): return event def date_gen(start = datetime(2012, 6, 6, 0), delta = timedelta(minutes = 1), n = 100): - return (start + i * delta for i in xrange(n)) + return (start + (i * delta) for i in xrange(n)) def alternate(g1, g2): for e1, e2 in izip_longest(g1, g2): From 3621934a2843f7d6d3ec26b5d010e8d36b16670f Mon Sep 17 00:00:00 2001 From: scottsanderson Date: Sat, 28 Jul 2012 18:24:57 -0400 Subject: [PATCH 4/7] better variable names and PreTransformLayer --- zipline/gens/composites.py | 23 ++++++----- zipline/gens/feed.py | 5 ++- zipline/gens/test_feed.py | 48 +++++++++++++++++----- zipline/gens/test_mongods.py | 4 +- zipline/gens/tradegens.py | 78 ++++++++++++++++++++++-------------- zipline/gens/utils.py | 15 +++++-- zipline/utils/factory.py | 4 +- 7 files changed, 119 insertions(+), 58 deletions(-) diff --git a/zipline/gens/composites.py b/zipline/gens/composites.py index ac24c7ce..1f68d58e 100644 --- a/zipline/gens/composites.py +++ b/zipline/gens/composites.py @@ -1,17 +1,18 @@ - +from zipline.gens.utils import roundrobin from zipline.gens.feed import FeedGen -from zipline.gens.tradegen import SpecificEquityTrades -from zipline.gens.transform +def PreTransformLayer(sources, source_ids): + """ + A generator that takes a tuple of sources and a list ids, piping + their output into a feed_gen. + """ + stream_in = roundrobin(*sources) + return FeedGen(stream_in, source_ids) +def TransformLayer(feed_stream, tnfms): + """ """ + pass -def PreTransformLayer(sources): - """A generator that takes a list of sources and runs their output - through a FeedGen.""" - not_finished = len #NOT DONE - - while not_finished: - - + diff --git a/zipline/gens/feed.py b/zipline/gens/feed.py index e7498d5c..93e03371 100644 --- a/zipline/gens/feed.py +++ b/zipline/gens/feed.py @@ -13,7 +13,8 @@ from collections import deque, defaultdict from zipline import ndict from zipline.gens.utils import hash_args, assert_datasource_protocol, \ - assert_trade_protocol, assert_datasource_unframe_protocol + assert_trade_protocol, assert_datasource_unframe_protocol, \ + assert_feed_protocol import zipline.protocol as zp @@ -49,7 +50,7 @@ def FeedGen(stream_in, source_ids): while full(sources) and not done(sources): message = pop_oldest(sources) - assert feed_protocol(message) + assert_feed_protocol(message) yield message # We should have only a done message left in each queue. diff --git a/zipline/gens/test_feed.py b/zipline/gens/test_feed.py index 5e632f9d..9b77f9c8 100644 --- a/zipline/gens/test_feed.py +++ b/zipline/gens/test_feed.py @@ -13,8 +13,10 @@ from collections import deque from zipline import ndict from zipline.gens.feed import FeedGen, full, done, queue_is_full,queue_is_done,\ pop_oldest -from zipline.gens.utils import stringify_args, assert_datasource_protocol,\ - assert_trade_protocol, date_gen, alternate +from zipline.gens.utils import hash_args, assert_datasource_protocol,\ + assert_trade_protocol, alternate +from zipline.gens.tradegens import date_gen, SpecificEquityTrades +from zipline.gens.composites import PreTransformLayer import zipline.protocol as zp @@ -98,12 +100,11 @@ class FeedGenTestCase(TestCase): l = list(feed_gen) assert l == expected - def test_single_source(self): source_ids = ['a'] # 100 events, increasing by a minute at a time. type = zp.DATASOURCE_TYPE.TRADE - dates = list(date_gen(n = 1)) + dates = list(date_gen(count = 100)) dates.append("DONE") # [('a', date1, type), ('a', date2, type), ... ('a', "DONE", type)] @@ -125,7 +126,7 @@ class FeedGenTestCase(TestCase): # Set up source 'a'. Outputs 20 events with 2 minute deltas. delta_a = timedelta(minutes = 2) - dates_a = list(date_gen(delta = delta_a, n = 20)) + dates_a = list(date_gen(delta = delta_a, count = 20)) dates_a.append("DONE") events_a_args = zip(cycle(['a']), iter(dates_a), cycle([type])) @@ -133,7 +134,7 @@ class FeedGenTestCase(TestCase): # Set up source 'b'. Outputs 10 events with 1 minute deltas. delta_b = timedelta(minutes = 1) - dates_b = list(date_gen(delta = delta_b, n = 10)) + dates_b = list(date_gen(delta = delta_b, count = 10)) dates_b.append("DONE") events_b_args = zip(cycle(['b']), iter(dates_b), cycle([type])) @@ -153,12 +154,41 @@ class FeedGenTestCase(TestCase): sequential = chain(iter(events_a), iter(events_b)) self.run_FeedGen(sequential, expected, source_ids) - def test_with_specific_equity(self): - - + def test_full_feed_layer(self): + filter = [1,2] + source_a = SpecificEquityTrades(sids = [1,2,3,4], + start = datetime(2012,6,6,0), + delta = timedelta(minutes=1), + filter = filter + ) + id_a = "SpecificEquityTradesd175237b28d2f52df208c97cf4af896e" + + # Change the internal sid list to give us a different hash. + source_b = SpecificEquityTrades(sids = [1,2,3,5], + start = datetime(2012,6,6,0), + delta = timedelta(minutes=1), + filter = filter + ) + id_b = 'SpecificEquityTrades2bf2c2d6d01d4dbfc0b2818438ea8151' + + # Change the internal sid list to give us a different hash. + source_c = SpecificEquityTrades(sids = [1,2,3,6], + start = datetime(2012,6,6,0), + delta = timedelta(minutes=1), + filter = filter + ) + id_c = 'SpecificEquityTrades16f7437db2d14e5373ef20025f49a3fe' + + sources = (source_a, source_b, source_c) + source_ids = [id_a, id_b, id_c] + import nose.tools; nose.tools.set_trace() + feed_out = PreTransformLayer(sources, source_ids) + for i in feed_out: + print i + def mock_data_unframe(source_id, dt, type): event = ndict() event.source_id = source_id diff --git a/zipline/gens/test_mongods.py b/zipline/gens/test_mongods.py index ea19d90e..d9b8dbe5 100644 --- a/zipline/gens/test_mongods.py +++ b/zipline/gens/test_mongods.py @@ -10,7 +10,7 @@ from itertools import izip, izip_longest from datetime import datetime, timedelta from zipline.gens.mongods import create_pymongo_iterator, MongoTradeHistoryGen -from zipline.gens.utils import stringify_args, assert_datasource_protocol,\ +from zipline.gens.utils import hash_args, assert_datasource_protocol,\ assert_trade_protocol, mock_raw_event import zipline.protocol as zp @@ -107,7 +107,7 @@ class TestMongoDataGenerator(TestCase): for field in iter(['sid', 'dt', 'price', 'volume']): assert db[field] == expected[field] - # Expected output of stringify_args: + # Expected output of hash_args: assert db['source_id'] == \ 'MongoTradeHistoryGen983a27fd0710414239a5cde71ef5a8fc' diff --git a/zipline/gens/tradegens.py b/zipline/gens/tradegens.py index 619721e4..8327aa63 100644 --- a/zipline/gens/tradegens.py +++ b/zipline/gens/tradegens.py @@ -2,53 +2,76 @@ import random from itertools import chain, repeat, cycle, ifilter, izip from datetime import datetime, timedelta -from zipline.utils.factory import create_trade, create_trade -from zipline.gens.utils import date_gen +from zipline.utils.factory import create_trade +from zipline.gens.utils import hash_args -def mock_prices(n, rand = False): - """Utility to generate a set of prices. By default - cycles through values from 0.0 to 10.0 n times. Optional - flag to give random values between 0.0 and 10.0""" +def date_gen(start = datetime(2012, 6, 6, 0), + delta = timedelta(minutes = 1), + count = 100): + """ + Utility to generate a stream of dates. + """ + return (start + (i * delta) for i in xrange(count)) + +def mock_prices(count, rand = False): + """ + Utility to generate a stream of mock prices. By default + cycles through values from 0.0 to 10.0, n times. Optional + flag to give random values between 0.0 and 10.0 + """ if rand: - return (random.uniform(0.0, 10.0) for i in xrange(n)) + return (random.uniform(0.0, 10.0) for i in xrange(count)) else: - return (float(i % 11) for i in xrange(1,n+1)) + return (float(i % 11) for i in xrange(1,count+1)) -def mock_volumes(n, rand = False): - """Utility to generate a set of volumes. By default cycles +def mock_volumes(count, rand = False): + """ + Utility to generate a set of volumes. By default cycles through values from 100 to 1000, incrementing by 50. Optional - flag to give random values between 100 and 1000. """ - + flag to give random values between 100 and 1000. + """ if rand: - return (random.randrange(100, 1000) for i in xrange(n)) + return (random.randrange(100, 1000) for i in xrange(count)) else: - return ((i * 50)%900 + 100 for i in xrange(n)) + return ((i * 50)%900 + 100 for i in xrange(count)) def fuzzy_dates(count = 500): - """Add +-10 seconds to each event from a date_gen. Note that - this still guarantees sorting, since the default is minute separation - of events.""" - for date in date_gen(n = count): + """ + Add +-10 seconds to each event from a date_gen. Note that this + still guarantees sorting, since the default on date_gen is minute + separation of events. + """ + for date in date_gen(count = count): yield date + timedelta(seconds = random.randint(-10, 10)) -def SpecificEquityTrades(count = 500, sids = [1, 2], event_list = None, filter = None): - - """Returns the first n events of event_list if specified. - Otherwise generates a sensible stream of events.""" - +def SpecificEquityTrades(count = 500, + sids = [1, 2], + start = datetime(2012, 6, 6, 0), + delta = timedelta(minutes = 1), + event_list = None, + filter = None): + """ + Yields all events in event_list that match the given sid_filter. + If no event_list is specified, generates an internal stream of events + to filter. Returns all events if filter is None. + """ + arg_string = hash_args(count, sids, start, delta, filter) + namestring = "SpecificEquityTrades" + arg_string + if event_list: unfiltered = (event for event in event_list) else: - dates = date_gen(n = count) + dates = date_gen(count = count, start = start, delta = delta) prices = mock_prices(count) volumes = mock_volumes(count) sids = cycle(iter(sids)) arg_gen = izip(sids, prices, volumes, dates) - unfiltered = (create_trade(*args) for args in arg_gen) + unfiltered = (create_trade(*args, source_id = namestring) + for args in arg_gen) if filter: filtered = ifilter(lambda event: event.sid in filter, unfiltered) else: @@ -72,9 +95,6 @@ def RandomEquityTrades(count = 500, sids = [1,2], filter = None): filtered = unfiltered return filtered -if __name__ == "__main__": - rand = RandomEquityTrades() - pass -# x = mock_volumes(500) +# if __name__ == "__main__": # import nose.tools; nose.tools.set_trace() # trades = SpecificEquityTrades(filter = [1]) diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index 1060cd24..8966c638 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -16,16 +16,25 @@ def mock_raw_event(sid, dt): } return event -def date_gen(start = datetime(2012, 6, 6, 0), delta = timedelta(minutes = 1), n = 100): - return (start + (i * delta) for i in xrange(n)) - def alternate(g1, g2): + """Specialized version of roundrobin for just 2 generators.""" for e1, e2 in izip_longest(g1, g2): if e1 != None: yield e1 if e2 != None: yield e2 +def roundrobin(*args): + """ + Takes N generators, pulling one element off each until all inputs + are empty. + """ + for elem_tuple in izip_longest(*args): + for value in elem_tuple: + if value != None: + yield value + + def hash_args(*args, **kwargs): """Define a unique string for any set of representable args.""" arg_string = '_'.join([str(arg) for arg in args]) diff --git a/zipline/utils/factory.py b/zipline/utils/factory.py index 23c2eb3e..db440891 100644 --- a/zipline/utils/factory.py +++ b/zipline/utils/factory.py @@ -69,9 +69,9 @@ def create_trading_environment(year=2006): return trading_environment -def create_trade(sid, price, amount, datetime): +def create_trade(sid, price, amount, datetime, source_id = "test_factory"): row = zp.ndict({ - 'source_id' : "test_factory", + 'source_id' : source_id, 'type' : zp.DATASOURCE_TYPE.TRADE, 'sid' : sid, 'dt' : datetime, From e048e8bc352631747dcc757b70631708d72a686b Mon Sep 17 00:00:00 2001 From: scottsanderson Date: Sat, 28 Jul 2012 19:04:50 -0400 Subject: [PATCH 5/7] added done message to SpecificEquity --- zipline/gens/feed.py | 3 ++- zipline/gens/test_feed.py | 4 +--- zipline/gens/tradegens.py | 6 ++++-- zipline/gens/utils.py | 3 +++ 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/zipline/gens/feed.py b/zipline/gens/feed.py index 93e03371..0528d7ee 100644 --- a/zipline/gens/feed.py +++ b/zipline/gens/feed.py @@ -52,7 +52,8 @@ def FeedGen(stream_in, source_ids): message = pop_oldest(sources) assert_feed_protocol(message) yield message - + + import nose.tools; nose.tools.set_trace() # We should have only a done message left in each queue. for queue in sources.itervalues(): assert len(queue) == 1, "Bad queue in FeedGen on exit: %s" % queue diff --git a/zipline/gens/test_feed.py b/zipline/gens/test_feed.py index 9b77f9c8..12cf4bc5 100644 --- a/zipline/gens/test_feed.py +++ b/zipline/gens/test_feed.py @@ -184,10 +184,8 @@ class FeedGenTestCase(TestCase): sources = (source_a, source_b, source_c) source_ids = [id_a, id_b, id_c] - import nose.tools; nose.tools.set_trace() feed_out = PreTransformLayer(sources, source_ids) - for i in feed_out: - print i + l = list(feed_out) def mock_data_unframe(source_id, dt, type): event = ndict() diff --git a/zipline/gens/tradegens.py b/zipline/gens/tradegens.py index 8327aa63..35100f8e 100644 --- a/zipline/gens/tradegens.py +++ b/zipline/gens/tradegens.py @@ -3,7 +3,7 @@ from itertools import chain, repeat, cycle, ifilter, izip from datetime import datetime, timedelta from zipline.utils.factory import create_trade -from zipline.gens.utils import hash_args +from zipline.gens.utils import hash_args, mock_done def date_gen(start = datetime(2012, 6, 6, 0), delta = timedelta(minutes = 1), @@ -77,7 +77,9 @@ def SpecificEquityTrades(count = 500, else: filtered = unfiltered - return filtered + # Add a done message to the end of the stream. + out = chain(filtered, iter([mock_done(namestring)])) + return out def RandomEquityTrades(count = 500, sids = [1,2], filter = None): dates = fuzzy_dates(500) diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index 8966c638..8ed6ea9d 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -16,6 +16,9 @@ def mock_raw_event(sid, dt): } return event +def mock_done(source_id): + return ndict({'dt': "DONE", "source_id" : source_id, 'type' : 0}) + def alternate(g1, g2): """Specialized version of roundrobin for just 2 generators.""" for e1, e2 in izip_longest(g1, g2): From fe1740a3cef30dfa34e37a175bc2401f37ad95aa Mon Sep 17 00:00:00 2001 From: scottsanderson Date: Sun, 29 Jul 2012 19:56:10 -0400 Subject: [PATCH 6/7] updates for transforms --- zipline/gens/composites.py | 64 ++++++++++++++++++++++++++++++++------ zipline/gens/feed.py | 1 - zipline/gens/test_feed.py | 62 ++++++++++++++++++------------------ zipline/gens/transform.py | 61 ++++++++++++++++++++++-------------- zipline/gens/utils.py | 6 ++-- 5 files changed, 124 insertions(+), 70 deletions(-) diff --git a/zipline/gens/composites.py b/zipline/gens/composites.py index 1f68d58e..ba0538b6 100644 --- a/zipline/gens/composites.py +++ b/zipline/gens/composites.py @@ -1,18 +1,62 @@ -from zipline.gens.utils import roundrobin +from itertools import tee + +from zipline.gens.tradegens import SpecificEquityTrades +from zipline.gens.utils import roundrobin, hash_args from zipline.gens.feed import FeedGen +from zipline.gens.merge import MergeGen +def PreTransformLayer(sources, source_args, source_kwargs): + """ + Takes a list of generator functions, a list of tuples of positional arguments, + and a list of dictionaries of keyword arguments. Packages up all arguments + and passes them into a FeedGen. + """ + assert len(sources) == len(source_args) == len(source_kwargs) + # Package up sources and arguments. + arg_bundles = zip(sources, source_args, source_kwargs) + # Calculate namestring hashes to pass to FeedGen. + namestrings = [source.__name__ + hash_args(*args, **kwargs) + for source, args, kwargs in arg_bundles] + # Pass each source its arguments. + initialized = tuple(source(*args, **kwargs) + for source, args, kwargs in arg_bundles) -def PreTransformLayer(sources, source_ids): - """ - A generator that takes a tuple of sources and a list ids, piping - their output into a feed_gen. - """ - stream_in = roundrobin(*sources) + stream_in = roundrobin(*initialized) return FeedGen(stream_in, source_ids) -def TransformLayer(feed_stream, tnfms): - """ """ - pass +def TransformLayer(feed_stream, tnfms, tnfm_args, tnfm_kwargs): + """ + A generator that takes the expected output of a FeedGen, pipes it + through a given set of transforms, and runs the results throught a + MergeGen to output a unified stream. tnfms should be a list of + pointers to generator functions. tnfm_args should be a list of + tuples, representing the arguments to be passed to each transform. + tnfm_kwargs should be a list of dictionaries representing keyword + arguments to each transform. + """ + # We should have as many sets of args as we have transforms. + assert len(tnfms) == len(tnfm_args) == len(tnfm_kwargs) + # Create a copy of the stream for each transform. + split = tee(feed_stream, len(tnfms)) + + # Package each stream copy with a transform and set of args. Use a list + # so that we can re-use this for calculating hashes. + bundles = zip(split, iter(tnfms), iter(tnfm_args), iter(tnfm_kwargs)) + + # Convert the argument bundles into a tuple of transform objects. + transformed = tuple((tnfm(stream, *args, **kwargs) + for stream, tnfm, args, kwargs in iter(bundles))) + + # Roundrobin the outputs of our transforms to create a single flat stream. + to_merge = roundrobin(*transformed) + + merged = MergeGen() + + +if __name__ == "__main__": + + source = SpecificEquityTrades() + diff --git a/zipline/gens/feed.py b/zipline/gens/feed.py index 0528d7ee..284023a0 100644 --- a/zipline/gens/feed.py +++ b/zipline/gens/feed.py @@ -53,7 +53,6 @@ def FeedGen(stream_in, source_ids): assert_feed_protocol(message) yield message - import nose.tools; nose.tools.set_trace() # We should have only a done message left in each queue. for queue in sources.itervalues(): assert len(queue) == 1, "Bad queue in FeedGen on exit: %s" % queue diff --git a/zipline/gens/test_feed.py b/zipline/gens/test_feed.py index 12cf4bc5..7c486e33 100644 --- a/zipline/gens/test_feed.py +++ b/zipline/gens/test_feed.py @@ -153,39 +153,41 @@ class FeedGenTestCase(TestCase): sequential = chain(iter(events_a), iter(events_b)) self.run_FeedGen(sequential, expected, source_ids) - + def test_full_feed_layer(self): filter = [1,2] + #Set up source a. + args_a = tuple() + kwargs_a = {'sids' : [1,2,3,4], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(minutes = 1), + 'filter' : filter + } + #Set up source b. + args_b = tuple() + kwargs_b = {'sids' : [1,2,3,5], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(minutes = 1), + 'filter' : filter + } + #Set up source c. + args_c = tuple() + kwargs_c = {'sids' : [1,2,3,5], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(minutes = 1), + 'filter' : filter + } - source_a = SpecificEquityTrades(sids = [1,2,3,4], - start = datetime(2012,6,6,0), - delta = timedelta(minutes=1), - filter = filter - ) - id_a = "SpecificEquityTradesd175237b28d2f52df208c97cf4af896e" - - # Change the internal sid list to give us a different hash. - source_b = SpecificEquityTrades(sids = [1,2,3,5], - start = datetime(2012,6,6,0), - delta = timedelta(minutes=1), - filter = filter - ) + sources = tuple(SpecificEquityTrades) * 3 + source_args = (args_a, args_b, args_c) + source_kwargs = (kwargs_a, kwargs_b, kwargs_c) - id_b = 'SpecificEquityTrades2bf2c2d6d01d4dbfc0b2818438ea8151' - - # Change the internal sid list to give us a different hash. - source_c = SpecificEquityTrades(sids = [1,2,3,6], - start = datetime(2012,6,6,0), - delta = timedelta(minutes=1), - filter = filter - ) - id_c = 'SpecificEquityTrades16f7437db2d14e5373ef20025f49a3fe' - - sources = (source_a, source_b, source_c) - source_ids = [id_a, id_b, id_c] + feed_out = PreTransformLayer(sources, source_args, source_kwargs) + to_list = list(feed_out) + copy = to_list[:] + expected = sorted(copy, compare_by_dt_source_id) - feed_out = PreTransformLayer(sources, source_ids) - l = list(feed_out) + assert to_list == expected def mock_data_unframe(source_id, dt, type): event = ndict() @@ -210,7 +212,3 @@ def compare_by_dt_source_id(x,y): else: return 0 - - - - diff --git a/zipline/gens/transform.py b/zipline/gens/transform.py index bdd85ee9..666fdd46 100644 --- a/zipline/gens/transform.py +++ b/zipline/gens/transform.py @@ -14,8 +14,9 @@ from numbers import Number from itertools import izip from zipline import ndict -from zipline.gens.utils import hash_args, date_gen -from zipline.gens.utils import assert_feed_unframe_protocol, assert_transform_protocol +from zipline.gens.tradegens import date_gen +from zipline.gens.utils import assert_feed_unframe_protocol, \ + assert_transform_protocol, hash_args import zipline.protocol as zp @@ -43,6 +44,10 @@ def FunctionalTransformGen(stream_in, fun, *args, **kwargs): """ # TODO: Distinguish between functions and classes in hash_args. + # As implemented we will get assertion errors if a function and + # stateful class have the same name, which may or may not be + # what we want. + namestring = fun.__name__ + hash_args(*args, **kwargs) for message in stream_in: @@ -75,7 +80,10 @@ def MovingAverageTransformGen(stream_in, days, fields): Generator that uses the MovingAverage state class to calculate a moving average for all stocks over a specified number of days. """ - return StatefulTransformGen(stream_in, MovingAverage, timedelta(days=days), fields) + return StatefulTransformGen(stream_in, + MovingAverage, + timedelta(days=days), + fields) class MovingAverage(object): """ @@ -91,7 +99,7 @@ class MovingAverage(object): # No way to pass arguments to the defaultdict factory, so we # need to define a method to generate the correct EventWindows. self.sid_windows = defaultdict(self.create_window) - + def create_window(self): """Factory method for self.sid_windows.""" return EventWindow(self.delta, self.fields) @@ -105,12 +113,16 @@ class MovingAverage(object): assert isinstance(event, ndict),"Bad event in MovingAverage: %s" % event assert event.has_key('sid'), "No sid in MovingAverage: %s" % event + output = ndict({'sid': event.sid}) # This will create a new EventWindow if this is the first # message for this sid. window = self.sid_windows[event.sid] window.update(event) + averages = window.get_averages() - return window.get_averages() + # Return the calculated averages along with + output.merge(averages) + return output class EventWindow(object): """ @@ -144,7 +156,7 @@ class EventWindow(object): # newest oldest # | | # V V - + while (self.ticks[-1].dt - self.ticks[0].dt) >= self.delta: # popleft removes and returns ticks[0] popped = self.ticks.popleft() @@ -168,6 +180,7 @@ class EventWindow(object): Return an ndict of all our tracked averages. """ out = ndict() + # out.ticks = len(self.ticks) for field in self.fields: out[field] = self.average(field) return out @@ -186,26 +199,26 @@ class EventWindow(object): assert isinstance(event[field], Number), \ "Got %s for %s in EventWindow" % (event[field], field) -if __name__ == "__main__": +# if __name__ == "__main__": - def make_event(**kwargs): - e = ndict() - for key, value in kwargs.iteritems(): - e[key] = value - return e +# def make_event(**kwargs): +# e = ndict() +# for key, value in kwargs.iteritems(): +# e[key] = value +# return e - dates = date_gen(delta = timedelta(hours = 12)) - events = ( - make_event( - sid = 'foo', price = random.random(), - dt = date, - type = zp.DATASOURCE_TYPE.TRADE, - source_id = 'ds', - vol = i - ) - for date, i in izip(dates, xrange(100)) - ) +# dates = date_gen(delta = timedelta(hours = 12)) +# events = ( +# make_event( +# sid = 'foo', price = random.random(), +# dt = date, +# type = zp.DATASOURCE_TYPE.TRADE, +# source_id = 'ds', +# vol = i +# ) +# for date, i in izip(dates, xrange(100)) +# ) - gen = MovingAverageTransformGen(events, 1, ['price', 'vol']) +# gen = MovingAverageTransformGen(events, 1, ['price', 'vol']) diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index 8ed6ea9d..6150757b 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -85,14 +85,14 @@ def assert_feed_protocol(event): assert event.type in DATASOURCE_TYPE assert event.has_key('dt') - def assert_feed_unframe_protocol(event): """Same as above.""" assert isinstance(event, ndict) + assert isinstance(event.source_id, basestring) assert event.type in DATASOURCE_TYPE assert event.has_key('dt') - def assert_transform_protocol(event): - pass + """Transforms should return an ndict to be merged by MergeGen.""" + assert isinstance(event, ndict) From 5d9bfe6b92319c4ff1d901d8c72a2de169dcc07a Mon Sep 17 00:00:00 2001 From: scottsanderson Date: Mon, 30 Jul 2012 13:51:10 -0400 Subject: [PATCH 7/7] full sequencing system (minus done from xforms) --- zipline/gens/composites.py | 42 ++++++++++++++++++------- zipline/gens/feed.py | 7 +++-- zipline/gens/merge.py | 40 +++++++++++++----------- zipline/gens/test_feed.py | 45 ++++++++++++++++++++------- zipline/gens/tradegens.py | 64 ++++++++++++++++++++++++++------------ zipline/gens/transform.py | 38 ++++++++-------------- zipline/gens/utils.py | 5 +++ zipline/gens/zmq_gens.py | 16 ++++++++++ 8 files changed, 167 insertions(+), 90 deletions(-) create mode 100644 zipline/gens/zmq_gens.py diff --git a/zipline/gens/composites.py b/zipline/gens/composites.py index ba0538b6..5ee45de6 100644 --- a/zipline/gens/composites.py +++ b/zipline/gens/composites.py @@ -1,9 +1,11 @@ +import datetime from itertools import tee from zipline.gens.tradegens import SpecificEquityTrades from zipline.gens.utils import roundrobin, hash_args from zipline.gens.feed import FeedGen from zipline.gens.merge import MergeGen +from zipline.gens.transform import StatefulTransformGen def PreTransformLayer(sources, source_args, source_kwargs): """ @@ -14,15 +16,17 @@ def PreTransformLayer(sources, source_args, source_kwargs): assert len(sources) == len(source_args) == len(source_kwargs) # Package up sources and arguments. arg_bundles = zip(sources, source_args, source_kwargs) + # Calculate namestring hashes to pass to FeedGen. namestrings = [source.__name__ + hash_args(*args, **kwargs) for source, args, kwargs in arg_bundles] # Pass each source its arguments. initialized = tuple(source(*args, **kwargs) for source, args, kwargs in arg_bundles) - + stream_in = roundrobin(*initialized) - return FeedGen(stream_in, source_ids) + return FeedGen(stream_in, namestrings) + def TransformLayer(feed_stream, tnfms, tnfm_args, tnfm_kwargs): """ @@ -34,6 +38,7 @@ def TransformLayer(feed_stream, tnfms, tnfm_args, tnfm_kwargs): tnfm_kwargs should be a list of dictionaries representing keyword arguments to each transform. """ + # We should have as many sets of args as we have transforms. assert len(tnfms) == len(tnfm_args) == len(tnfm_kwargs) @@ -42,21 +47,34 @@ def TransformLayer(feed_stream, tnfms, tnfm_args, tnfm_kwargs): # Package each stream copy with a transform and set of args. Use a list # so that we can re-use this for calculating hashes. - bundles = zip(split, iter(tnfms), iter(tnfm_args), iter(tnfm_kwargs)) - - # Convert the argument bundles into a tuple of transform objects. - transformed = tuple((tnfm(stream, *args, **kwargs) - for stream, tnfm, args, kwargs in iter(bundles))) + bundles = zip(split, tnfms, tnfm_args, tnfm_kwargs) + + tnfm_gens = [StatefulTransformGen(stream, tnfm, *args, **kwargs) + for stream, tnfm, args, kwargs in bundles] + # Generate expected hashes for each transform + hashes = [tnfm.__name__ + hash_args(*args, **kwargs) + for _, tnfm, args, kwargs in bundles] + # Roundrobin the outputs of our transforms to create a single flat stream. - to_merge = roundrobin(*transformed) - - merged = MergeGen() - + to_merge = roundrobin(*tnfm_gens) + # Pipe the stream into MergeGen. + merged = MergeGen(to_merge, hashes) + return merged + if __name__ == "__main__": + from zipline.gens.transform import MovingAverage, Passthrough - source = SpecificEquityTrades() + import nose.tools; nose.tools.set_trace() + source = SpecificEquityTrades + feed_out = PreTransformLayer((source,), ((),), ({},)) + + transforms = [MovingAverage, Passthrough] + args = [(datetime.timedelta(days = 1), ['price']), ()] + kwargs = [{}, {}] + + tlayer = TransformLayer(feed_out, transforms, args, kwargs) diff --git a/zipline/gens/feed.py b/zipline/gens/feed.py index 284023a0..f029d0ee 100644 --- a/zipline/gens/feed.py +++ b/zipline/gens/feed.py @@ -39,7 +39,7 @@ def FeedGen(stream_in, source_ids): # Incoming messages should be the output of DATASOURCE_UNFRAME. assert_datasource_unframe_protocol(message), \ "Bad message in FeedGen: %s" % message - + # Only allow messages from sources we expect. assert message.source_id in sources, "Unexpected source: %s" % message @@ -61,8 +61,9 @@ def FeedGen(stream_in, source_ids): def full(sources): """ - Feed is full when every internal queue has at least one message. Note that - this include DONE messages, so done(sources) is True only if full(sources). + Feed is full when every internal queue has at least one + message. Note that this include DONE messages, so done(sources) is + True only if full(sources). """ assert isinstance(sources, dict) return all( (queue_is_full(source) for source in sources.itervalues()) ) diff --git a/zipline/gens/merge.py b/zipline/gens/merge.py index 7c0a195c..52adcb84 100644 --- a/zipline/gens/merge.py +++ b/zipline/gens/merge.py @@ -13,7 +13,7 @@ from collections import deque, defaultdict from zipline import ndict from zipline.gens.utils import hash_args, assert_datasource_protocol, \ - assert_trade_protocol, assert_datasource_unframe_protocol + assert_trade_protocol, assert_datasource_unframe_protocol, assert_merge_protocol import zipline.protocol as zp @@ -25,44 +25,46 @@ def MergeGen(stream_in, tnfm_ids): and merge them together into an event. We raise an error if we do not receive the same number of events from all sources. """ - - assert isinstance(source_ids, list) + assert isinstance(tnfm_ids, list) + # Set up an internal queue for each expected source. - sources = {} - for id in source_ids: - assert isinstance(id, basestring), "Bad source_id %s" % source_id - sources[id] = deque() + tnfms = {} + for id in tnfm_ids: + assert isinstance(id, basestring), "Bad source_id %s" % id + tnfms[id] = deque() # Process incoming streams. for message in stream_in: - assert isinstance(message, ndict), \ + assert isinstance(message, tuple), \ "Bad message in MergeGen: %s" %message - assert message.tnfm_id in tnfm_ids, \ - "Message from unexpected tnfm: %s, %s" % (message, tnfm_ids) + assert len(message) == 2 + id, value = message + assert id in tnfm_ids, \ + "Message from unexpected tnfm: %s, %s" % (id, tnfm_ids) + assert isinstance(value, ndict), "Bad message in MergeGen: %s" %message - assert message.has_key('value') - - source[message.tnfm_id].append(message) + tnfms[id].append(value) # Only pop messages when we have a pending message from # all datasources. Stop if all sources have signalled done. - while full(sources) and not done(sources): - message = merge_one(sources) - assert merge_protocol(message) + while full(tnfms) and not done(tnfms): + message = merge_one(tnfms) + assert_merge_protocol(tnfm_ids, message) yield message # We should have only a done message left in each queue. - for queue in sources.itervalues(): + for queue in tnfms.itervalues(): assert len(queue) == 1, "Bad queue in MergeGen on exit: %s" % queue assert queue[0].dt == "DONE", \ "Bad last message in MergeGen on exit: %s" % queue def merge_one(sources): output = ndict() - for queue in sources.itervalues(): - output.merge(queue.popleft()) + for key, queue in sources.iteritems(): + new_xform = ndict({key: queue.popleft()}) + output.merge(new_xform) return output diff --git a/zipline/gens/test_feed.py b/zipline/gens/test_feed.py index 7c486e33..b5007b91 100644 --- a/zipline/gens/test_feed.py +++ b/zipline/gens/test_feed.py @@ -155,38 +155,61 @@ class FeedGenTestCase(TestCase): self.run_FeedGen(sequential, expected, source_ids) def test_full_feed_layer(self): + filter = [1,2] - #Set up source a. + #Set up source a. One hour between events. args_a = tuple() kwargs_a = {'sids' : [1,2,3,4], 'start' : datetime(2012,6,6,0), - 'delta' : timedelta(minutes = 1), + 'delta' : timedelta(hours = 1), 'filter' : filter } - #Set up source b. + #Set up source b. One day between events. args_b = tuple() - kwargs_b = {'sids' : [1,2,3,5], + kwargs_b = {'sids' : [1,2,3,4], 'start' : datetime(2012,6,6,0), - 'delta' : timedelta(minutes = 1), + 'delta' : timedelta(days = 1), 'filter' : filter } - #Set up source c. + #Set up source c. One minute between events. args_c = tuple() - kwargs_c = {'sids' : [1,2,3,5], + kwargs_c = {'sids' : [1,2,3,4], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(minutes = 1), + 'filter' : filter + } + # Set up source d. This should produce no events because the + # internal sids don't match the filter. + args_d = tuple() + kwargs_d = {'sids' : [3,4], 'start' : datetime(2012,6,6,0), 'delta' : timedelta(minutes = 1), 'filter' : filter } - sources = tuple(SpecificEquityTrades) * 3 - source_args = (args_a, args_b, args_c) - source_kwargs = (kwargs_a, kwargs_b, kwargs_c) + sources = (SpecificEquityTrades,) * 4 + source_args = (args_a, args_b, args_c, args_d) + source_kwargs = (kwargs_a, kwargs_b, kwargs_c, kwargs_d) + # Generate our expected source_ids. + zip_args = zip(source_args, source_kwargs) + expected_ids = ["SpecificEquityTrades" + hash_args(*args, **kwargs) + for args, kwargs in zip_args] + + # Pipe our sources into feed. feed_out = PreTransformLayer(sources, source_args, source_kwargs) + + # Read all the values from feed and assert that they arrive in + # the correct sorting with the expected hash values. to_list = list(feed_out) copy = to_list[:] + for e in to_list: + # All events should match one of our expected source_ids. + assert e.source_id in expected_ids + # But none of them should match source_d. + assert e.source_id != hash_args(*args_d, **kwargs_d) + expected = sorted(copy, compare_by_dt_source_id) - assert to_list == expected def mock_data_unframe(source_id, dt, type): diff --git a/zipline/gens/tradegens.py b/zipline/gens/tradegens.py index 35100f8e..fb0b3f48 100644 --- a/zipline/gens/tradegens.py +++ b/zipline/gens/tradegens.py @@ -45,47 +45,71 @@ def fuzzy_dates(count = 500): for date in date_gen(count = count): yield date + timedelta(seconds = random.randint(-10, 10)) -def SpecificEquityTrades(count = 500, - sids = [1, 2], - start = datetime(2012, 6, 6, 0), - delta = timedelta(minutes = 1), - event_list = None, - filter = None): +def SpecificEquityTrades(*args, **config): """ Yields all events in event_list that match the given sid_filter. If no event_list is specified, generates an internal stream of events to filter. Returns all events if filter is None. """ - arg_string = hash_args(count, sids, start, delta, filter) - namestring = "SpecificEquityTrades" + arg_string + # We shouldn't get any positional arguments. + assert args == () + # Unpack config dictionary with default values. + count = config.get('count', 500) + sids = config.get('sids', [1, 2]) + start = config.get('start', datetime(2012, 6, 6, 0)) + delta = config.get('delta', timedelta(minutes = 1)) + + # Default to None for event_list and filter. + event_list = config.get('event_list') + filter = config.get('filter') + + arg_string = hash_args(*args, **config) + namestring = "SpecificEquityTrades" + arg_string + # If we have an event_list, ignore the other arguments and use the list. + # TODO: still append our namestring? if event_list: unfiltered = (event for event in event_list) - + + # Set up iterators for each expected field. else: dates = date_gen(count = count, start = start, delta = delta) prices = mock_prices(count) volumes = mock_volumes(count) - sids = cycle(iter(sids)) - + sids = cycle(sids) + + # Combine the iterators into a single iterator of arguments arg_gen = izip(sids, prices, volumes, dates) - + + # Convert argument packages into events. unfiltered = (create_trade(*args, source_id = namestring) for args in arg_gen) + + # If we specified a sid filter, filter out elements that don't match the filter. if filter: filtered = ifilter(lambda event: event.sid in filter, unfiltered) + + # Otherwise just use all events. else: filtered = unfiltered - # Add a done message to the end of the stream. - out = chain(filtered, iter([mock_done(namestring)])) - return out + # Add a done message to the end of the stream. For a live + # datasource this would be handled by the containing Component. + out = chain(filtered, [mock_done(namestring)]) + return out -def RandomEquityTrades(count = 500, sids = [1,2], filter = None): - dates = fuzzy_dates(500) - prices = mock_prices(500, rand = True) - volumes = mock_volumes(500, rand = True) - sids = cycle(iter(sids)) +def RandomEquityTrades(*args, **config): + # We shouldn't get any positional args. + assert args == () + + count = config.get('count', 500) + sids = config.get('sids', [1,2]) + filter = config.get('filter') + + dates = fuzzy_dates(count) + prices = mock_prices(count, rand = True) + volumes = mock_volumes(count, rand = True) + sids = cycle(sids) arg_gen = izip(sids, prices, volumes, dates) diff --git a/zipline/gens/transform.py b/zipline/gens/transform.py index 666fdd46..137b9d69 100644 --- a/zipline/gens/transform.py +++ b/zipline/gens/transform.py @@ -20,21 +20,18 @@ from zipline.gens.utils import assert_feed_unframe_protocol, \ import zipline.protocol as zp -def PassthroughTransformGen(stream_in): - """Trivial transform for event forwarding.""" +class Passthrough(object): + """ + Trivial function for forwarding events. + """ + def __init__(self): + pass - # hash_args with no arguments is the same as: - # hasher = hashlib.md5() - # hasher.update(":"); - # hashlib.md5.digest(). - - namestring = "Passthrough" + hash_args() - - for message in stream_in: - assert_feed_unframe_protocol(message) - out_value = message - assert_transform_protocol(out_value) - yield (namestring, out_value) + def update(self, event): + assert isinstance(event, ndict),"Bad event in Passthrough: %s" % event + assert event.has_key('sid'), "No sid in Passthrough: %s" % event + assert event.has_key('dt'), "No dt in Passthorughz: %s" % event + return event def FunctionalTransformGen(stream_in, fun, *args, **kwargs): """ @@ -75,16 +72,6 @@ def StatefulTransformGen(stream_in, tnfm_class, *args, **kwargs): assert_transform_protocol(out_value) yield (namestring, out_value) -def MovingAverageTransformGen(stream_in, days, fields): - """ - Generator that uses the MovingAverage state class to calculate - a moving average for all stocks over a specified number of days. - """ - return StatefulTransformGen(stream_in, - MovingAverage, - timedelta(days=days), - fields) - class MovingAverage(object): """ Class that maintains a dictionary from sids to EventWindows @@ -112,8 +99,9 @@ class MovingAverage(object): assert isinstance(event, ndict),"Bad event in MovingAverage: %s" % event assert event.has_key('sid'), "No sid in MovingAverage: %s" % event + assert event.has_key('dt'), "No dt in MovingAverage: %s" % event - output = ndict({'sid': event.sid}) + output = ndict({'sid': event.sid, 'dt': event.dt}) # This will create a new EventWindow if this is the first # message for this sid. window = self.sid_windows[event.sid] diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index 6150757b..c10cffdf 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -96,3 +96,8 @@ def assert_transform_protocol(event): """Transforms should return an ndict to be merged by MergeGen.""" assert isinstance(event, ndict) +def assert_merge_protocol(tnfm_ids, message): + """Merge should output an ndict with a field for each id in its transform set.""" + assert isinstance(message, ndict) + assert set(tnfm_ids) == set(message.keys()) + diff --git a/zipline/gens/zmq_gens.py b/zipline/gens/zmq_gens.py new file mode 100644 index 00000000..524852a7 --- /dev/null +++ b/zipline/gens/zmq_gens.py @@ -0,0 +1,16 @@ +import zmq + +import zipline.protocol as zp + +def gen_from_zmq(poller, unframe): + """ + A generator that takes an initialized zmq poller and yields + messages from the poller until it gets a zp.CONTROL_PROTOCOL.DONE. + """ + while True: + message = poller.recv() + if message = zp.CONTROL_PROTOCOL.DONE: + yield "DONE" + break + else: + yield unframe(message)