diff --git a/zipline/finance/sources.py b/zipline/finance/sources.py index 9cde473c..b9f818ea 100644 --- a/zipline/finance/sources.py +++ b/zipline/finance/sources.py @@ -83,7 +83,7 @@ class RandomEquityTrades(TradeDataSource): class SpecificEquityTrades(TradeDataSource): """ - Generates a random stream of trades for testing. + Generates a non-random stream of trades for testing. """ def init(self, event_list): diff --git a/zipline/gens/composites.py b/zipline/gens/composites.py index 1b710212..5ee45de6 100644 --- a/zipline/gens/composites.py +++ b/zipline/gens/composites.py @@ -1,17 +1,80 @@ +import datetime +from itertools import tee +from zipline.gens.tradegens import SpecificEquityTrades +from zipline.gens.utils import roundrobin, hash_args from zipline.gens.feed import FeedGen -from zipline.gens.tradegen import SpecificEquityTrades -from zipline.gens.transform +from zipline.gens.merge import MergeGen +from zipline.gens.transform import StatefulTransformGen - - - -def PreTransformLayer(sources): - """A generator that takes a list of sources and runs their output - through a FeedGen.""" - not_finished = len_ - - while not_finished: - - +def PreTransformLayer(sources, source_args, source_kwargs): + """ + Takes a list of generator functions, a list of tuples of positional arguments, + and a list of dictionaries of keyword arguments. Packages up all arguments + and passes them into a FeedGen. + """ + assert len(sources) == len(source_args) == len(source_kwargs) + # Package up sources and arguments. + arg_bundles = zip(sources, source_args, source_kwargs) + + # Calculate namestring hashes to pass to FeedGen. + namestrings = [source.__name__ + hash_args(*args, **kwargs) + for source, args, kwargs in arg_bundles] + # Pass each source its arguments. + initialized = tuple(source(*args, **kwargs) + for source, args, kwargs in arg_bundles) + + stream_in = roundrobin(*initialized) + return FeedGen(stream_in, namestrings) + + +def TransformLayer(feed_stream, tnfms, tnfm_args, tnfm_kwargs): + """ + A generator that takes the expected output of a FeedGen, pipes it + through a given set of transforms, and runs the results throught a + MergeGen to output a unified stream. tnfms should be a list of + pointers to generator functions. tnfm_args should be a list of + tuples, representing the arguments to be passed to each transform. + tnfm_kwargs should be a list of dictionaries representing keyword + arguments to each transform. + """ + + # We should have as many sets of args as we have transforms. + assert len(tnfms) == len(tnfm_args) == len(tnfm_kwargs) + + # Create a copy of the stream for each transform. + split = tee(feed_stream, len(tnfms)) + + # Package each stream copy with a transform and set of args. Use a list + # so that we can re-use this for calculating hashes. + bundles = zip(split, tnfms, tnfm_args, tnfm_kwargs) + + tnfm_gens = [StatefulTransformGen(stream, tnfm, *args, **kwargs) + for stream, tnfm, args, kwargs in bundles] + + # Generate expected hashes for each transform + hashes = [tnfm.__name__ + hash_args(*args, **kwargs) + for _, tnfm, args, kwargs in bundles] + + # Roundrobin the outputs of our transforms to create a single flat stream. + to_merge = roundrobin(*tnfm_gens) + + # Pipe the stream into MergeGen. + merged = MergeGen(to_merge, hashes) + return merged + +if __name__ == "__main__": + from zipline.gens.transform import MovingAverage, Passthrough + + import nose.tools; nose.tools.set_trace() + source = SpecificEquityTrades + feed_out = PreTransformLayer((source,), ((),), ({},)) + + transforms = [MovingAverage, Passthrough] + args = [(datetime.timedelta(days = 1), ['price']), ()] + kwargs = [{}, {}] + + tlayer = TransformLayer(feed_out, transforms, args, kwargs) + + diff --git a/zipline/gens/feed.py b/zipline/gens/feed.py index e7498d5c..f029d0ee 100644 --- a/zipline/gens/feed.py +++ b/zipline/gens/feed.py @@ -13,7 +13,8 @@ from collections import deque, defaultdict from zipline import ndict from zipline.gens.utils import hash_args, assert_datasource_protocol, \ - assert_trade_protocol, assert_datasource_unframe_protocol + assert_trade_protocol, assert_datasource_unframe_protocol, \ + assert_feed_protocol import zipline.protocol as zp @@ -38,7 +39,7 @@ def FeedGen(stream_in, source_ids): # Incoming messages should be the output of DATASOURCE_UNFRAME. assert_datasource_unframe_protocol(message), \ "Bad message in FeedGen: %s" % message - + # Only allow messages from sources we expect. assert message.source_id in sources, "Unexpected source: %s" % message @@ -49,9 +50,9 @@ def FeedGen(stream_in, source_ids): while full(sources) and not done(sources): message = pop_oldest(sources) - assert feed_protocol(message) + assert_feed_protocol(message) yield message - + # We should have only a done message left in each queue. for queue in sources.itervalues(): assert len(queue) == 1, "Bad queue in FeedGen on exit: %s" % queue @@ -60,8 +61,9 @@ def FeedGen(stream_in, source_ids): def full(sources): """ - Feed is full when every internal queue has at least one message. Note that - this include DONE messages, so done(sources) is True only if full(sources). + Feed is full when every internal queue has at least one + message. Note that this include DONE messages, so done(sources) is + True only if full(sources). """ assert isinstance(sources, dict) return all( (queue_is_full(source) for source in sources.itervalues()) ) diff --git a/zipline/gens/merge.py b/zipline/gens/merge.py index 7c0a195c..52adcb84 100644 --- a/zipline/gens/merge.py +++ b/zipline/gens/merge.py @@ -13,7 +13,7 @@ from collections import deque, defaultdict from zipline import ndict from zipline.gens.utils import hash_args, assert_datasource_protocol, \ - assert_trade_protocol, assert_datasource_unframe_protocol + assert_trade_protocol, assert_datasource_unframe_protocol, assert_merge_protocol import zipline.protocol as zp @@ -25,44 +25,46 @@ def MergeGen(stream_in, tnfm_ids): and merge them together into an event. We raise an error if we do not receive the same number of events from all sources. """ - - assert isinstance(source_ids, list) + assert isinstance(tnfm_ids, list) + # Set up an internal queue for each expected source. - sources = {} - for id in source_ids: - assert isinstance(id, basestring), "Bad source_id %s" % source_id - sources[id] = deque() + tnfms = {} + for id in tnfm_ids: + assert isinstance(id, basestring), "Bad source_id %s" % id + tnfms[id] = deque() # Process incoming streams. for message in stream_in: - assert isinstance(message, ndict), \ + assert isinstance(message, tuple), \ "Bad message in MergeGen: %s" %message - assert message.tnfm_id in tnfm_ids, \ - "Message from unexpected tnfm: %s, %s" % (message, tnfm_ids) + assert len(message) == 2 + id, value = message + assert id in tnfm_ids, \ + "Message from unexpected tnfm: %s, %s" % (id, tnfm_ids) + assert isinstance(value, ndict), "Bad message in MergeGen: %s" %message - assert message.has_key('value') - - source[message.tnfm_id].append(message) + tnfms[id].append(value) # Only pop messages when we have a pending message from # all datasources. Stop if all sources have signalled done. - while full(sources) and not done(sources): - message = merge_one(sources) - assert merge_protocol(message) + while full(tnfms) and not done(tnfms): + message = merge_one(tnfms) + assert_merge_protocol(tnfm_ids, message) yield message # We should have only a done message left in each queue. - for queue in sources.itervalues(): + for queue in tnfms.itervalues(): assert len(queue) == 1, "Bad queue in MergeGen on exit: %s" % queue assert queue[0].dt == "DONE", \ "Bad last message in MergeGen on exit: %s" % queue def merge_one(sources): output = ndict() - for queue in sources.itervalues(): - output.merge(queue.popleft()) + for key, queue in sources.iteritems(): + new_xform = ndict({key: queue.popleft()}) + output.merge(new_xform) return output diff --git a/zipline/gens/test_feed.py b/zipline/gens/test_feed.py index 5e632f9d..b5007b91 100644 --- a/zipline/gens/test_feed.py +++ b/zipline/gens/test_feed.py @@ -13,8 +13,10 @@ from collections import deque from zipline import ndict from zipline.gens.feed import FeedGen, full, done, queue_is_full,queue_is_done,\ pop_oldest -from zipline.gens.utils import stringify_args, assert_datasource_protocol,\ - assert_trade_protocol, date_gen, alternate +from zipline.gens.utils import hash_args, assert_datasource_protocol,\ + assert_trade_protocol, alternate +from zipline.gens.tradegens import date_gen, SpecificEquityTrades +from zipline.gens.composites import PreTransformLayer import zipline.protocol as zp @@ -98,12 +100,11 @@ class FeedGenTestCase(TestCase): l = list(feed_gen) assert l == expected - def test_single_source(self): source_ids = ['a'] # 100 events, increasing by a minute at a time. type = zp.DATASOURCE_TYPE.TRADE - dates = list(date_gen(n = 1)) + dates = list(date_gen(count = 100)) dates.append("DONE") # [('a', date1, type), ('a', date2, type), ... ('a', "DONE", type)] @@ -125,7 +126,7 @@ class FeedGenTestCase(TestCase): # Set up source 'a'. Outputs 20 events with 2 minute deltas. delta_a = timedelta(minutes = 2) - dates_a = list(date_gen(delta = delta_a, n = 20)) + dates_a = list(date_gen(delta = delta_a, count = 20)) dates_a.append("DONE") events_a_args = zip(cycle(['a']), iter(dates_a), cycle([type])) @@ -133,7 +134,7 @@ class FeedGenTestCase(TestCase): # Set up source 'b'. Outputs 10 events with 1 minute deltas. delta_b = timedelta(minutes = 1) - dates_b = list(date_gen(delta = delta_b, n = 10)) + dates_b = list(date_gen(delta = delta_b, count = 10)) dates_b.append("DONE") events_b_args = zip(cycle(['b']), iter(dates_b), cycle([type])) @@ -152,13 +153,65 @@ class FeedGenTestCase(TestCase): sequential = chain(iter(events_a), iter(events_b)) self.run_FeedGen(sequential, expected, source_ids) + + def test_full_feed_layer(self): + + filter = [1,2] + #Set up source a. One hour between events. + args_a = tuple() + kwargs_a = {'sids' : [1,2,3,4], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(hours = 1), + 'filter' : filter + } + #Set up source b. One day between events. + args_b = tuple() + kwargs_b = {'sids' : [1,2,3,4], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(days = 1), + 'filter' : filter + } + #Set up source c. One minute between events. + args_c = tuple() + kwargs_c = {'sids' : [1,2,3,4], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(minutes = 1), + 'filter' : filter + } + # Set up source d. This should produce no events because the + # internal sids don't match the filter. + args_d = tuple() + kwargs_d = {'sids' : [3,4], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(minutes = 1), + 'filter' : filter + } + + sources = (SpecificEquityTrades,) * 4 + source_args = (args_a, args_b, args_c, args_d) + source_kwargs = (kwargs_a, kwargs_b, kwargs_c, kwargs_d) + + # Generate our expected source_ids. + zip_args = zip(source_args, source_kwargs) + expected_ids = ["SpecificEquityTrades" + hash_args(*args, **kwargs) + for args, kwargs in zip_args] + + # Pipe our sources into feed. + feed_out = PreTransformLayer(sources, source_args, source_kwargs) + + # Read all the values from feed and assert that they arrive in + # the correct sorting with the expected hash values. + to_list = list(feed_out) + copy = to_list[:] + for e in to_list: + # All events should match one of our expected source_ids. + assert e.source_id in expected_ids + # But none of them should match source_d. + assert e.source_id != hash_args(*args_d, **kwargs_d) + + expected = sorted(copy, compare_by_dt_source_id) + assert to_list == expected - def test_with_specific_equity(self): - - - - - def mock_data_unframe(source_id, dt, type): event = ndict() event.source_id = source_id @@ -182,7 +235,3 @@ def compare_by_dt_source_id(x,y): else: return 0 - - - - diff --git a/zipline/gens/test_mongods.py b/zipline/gens/test_mongods.py index ea19d90e..d9b8dbe5 100644 --- a/zipline/gens/test_mongods.py +++ b/zipline/gens/test_mongods.py @@ -10,7 +10,7 @@ from itertools import izip, izip_longest from datetime import datetime, timedelta from zipline.gens.mongods import create_pymongo_iterator, MongoTradeHistoryGen -from zipline.gens.utils import stringify_args, assert_datasource_protocol,\ +from zipline.gens.utils import hash_args, assert_datasource_protocol,\ assert_trade_protocol, mock_raw_event import zipline.protocol as zp @@ -107,7 +107,7 @@ class TestMongoDataGenerator(TestCase): for field in iter(['sid', 'dt', 'price', 'volume']): assert db[field] == expected[field] - # Expected output of stringify_args: + # Expected output of hash_args: assert db['source_id'] == \ 'MongoTradeHistoryGen983a27fd0710414239a5cde71ef5a8fc' diff --git a/zipline/gens/tradegens.py b/zipline/gens/tradegens.py index 24ece45b..fb0b3f48 100644 --- a/zipline/gens/tradegens.py +++ b/zipline/gens/tradegens.py @@ -1,44 +1,126 @@ import random -from itertools import chain, repeat, cycle, ifilter +from itertools import chain, repeat, cycle, ifilter, izip from datetime import datetime, timedelta -from zipline.utils.factory import create_trade, create_trade -from zipline.gens.utils import date_gen +from zipline.utils.factory import create_trade +from zipline.gens.utils import hash_args, mock_done -def mock_prices(n, rand = False): - """Utility to generate a set of prices. By default - cycles through values from 0.0 to 10.0 n times. Optional - flag to give random values between 0.0 and 10.0""" +def date_gen(start = datetime(2012, 6, 6, 0), + delta = timedelta(minutes = 1), + count = 100): + """ + Utility to generate a stream of dates. + """ + return (start + (i * delta) for i in xrange(count)) + +def mock_prices(count, rand = False): + """ + Utility to generate a stream of mock prices. By default + cycles through values from 0.0 to 10.0, n times. Optional + flag to give random values between 0.0 and 10.0 + """ if rand: - return (random.uniform(0.0, 10.0) for i in xrange(n)) + return (random.uniform(0.0, 10.0) for i in xrange(count)) else: - return (float(i % 11) for i in xrange(1,n+1)) + return (float(i % 11) for i in xrange(1,count+1)) -def mock_volumes(n, rand = False): - """Does the same as mock_prices. Different function name - for readability.""" - return mock_prices(n, rand) - -def SpecificEquityTrades(n = 500, sids = [1, 2], event_list = None, filter = None): - """Returns the first n events of event_list if specified. - Otherwise generates a sensible stream of events.""" +def mock_volumes(count, rand = False): + """ + Utility to generate a set of volumes. By default cycles + through values from 100 to 1000, incrementing by 50. Optional + flag to give random values between 100 and 1000. + """ + if rand: + return (random.randrange(100, 1000) for i in xrange(count)) + else: + return ((i * 50)%900 + 100 for i in xrange(count)) +def fuzzy_dates(count = 500): + """ + Add +-10 seconds to each event from a date_gen. Note that this + still guarantees sorting, since the default on date_gen is minute + separation of events. + """ + for date in date_gen(count = count): + yield date + timedelta(seconds = random.randint(-10, 10)) + +def SpecificEquityTrades(*args, **config): + """ + Yields all events in event_list that match the given sid_filter. + If no event_list is specified, generates an internal stream of events + to filter. Returns all events if filter is None. + """ + # We shouldn't get any positional arguments. + assert args == () + + # Unpack config dictionary with default values. + count = config.get('count', 500) + sids = config.get('sids', [1, 2]) + start = config.get('start', datetime(2012, 6, 6, 0)) + delta = config.get('delta', timedelta(minutes = 1)) + + # Default to None for event_list and filter. + event_list = config.get('event_list') + filter = config.get('filter') + + arg_string = hash_args(*args, **config) + namestring = "SpecificEquityTrades" + arg_string + # If we have an event_list, ignore the other arguments and use the list. + # TODO: still append our namestring? if event_list: unfiltered = (event for event in event_list) - + + # Set up iterators for each expected field. else: - dates = date_gen(n = n) - prices = mock_prices(n) - volumes = mock_volumes(n) - sids = cycle(iter(sids)) - + dates = date_gen(count = count, start = start, delta = delta) + prices = mock_prices(count) + volumes = mock_volumes(count) + sids = cycle(sids) + + # Combine the iterators into a single iterator of arguments arg_gen = izip(sids, prices, volumes, dates) - - unfiltered = (create_trade(*args) for args in arg_gen) + + # Convert argument packages into events. + unfiltered = (create_trade(*args, source_id = namestring) + for args in arg_gen) + + # If we specified a sid filter, filter out elements that don't match the filter. if filter: - filtered = ifilter(lambda event: event.sid in filter) + filtered = ifilter(lambda event: event.sid in filter, unfiltered) + + # Otherwise just use all events. else: filtered = unfiltered + # Add a done message to the end of the stream. For a live + # datasource this would be handled by the containing Component. + out = chain(filtered, [mock_done(namestring)]) + return out + +def RandomEquityTrades(*args, **config): + # We shouldn't get any positional args. + assert args == () + + count = config.get('count', 500) + sids = config.get('sids', [1,2]) + filter = config.get('filter') + + dates = fuzzy_dates(count) + prices = mock_prices(count, rand = True) + volumes = mock_volumes(count, rand = True) + sids = cycle(sids) + + arg_gen = izip(sids, prices, volumes, dates) + + unfiltered = (create_trade(*args) for args in arg_gen) + + if filter: + filtered = ifilter(lambda event: event.sid in filter, unfiltered) + else: + filtered = unfiltered return filtered + +# if __name__ == "__main__": +# import nose.tools; nose.tools.set_trace() +# trades = SpecificEquityTrades(filter = [1]) diff --git a/zipline/gens/transform.py b/zipline/gens/transform.py index bdd85ee9..137b9d69 100644 --- a/zipline/gens/transform.py +++ b/zipline/gens/transform.py @@ -14,26 +14,24 @@ from numbers import Number from itertools import izip from zipline import ndict -from zipline.gens.utils import hash_args, date_gen -from zipline.gens.utils import assert_feed_unframe_protocol, assert_transform_protocol +from zipline.gens.tradegens import date_gen +from zipline.gens.utils import assert_feed_unframe_protocol, \ + assert_transform_protocol, hash_args import zipline.protocol as zp -def PassthroughTransformGen(stream_in): - """Trivial transform for event forwarding.""" +class Passthrough(object): + """ + Trivial function for forwarding events. + """ + def __init__(self): + pass - # hash_args with no arguments is the same as: - # hasher = hashlib.md5() - # hasher.update(":"); - # hashlib.md5.digest(). - - namestring = "Passthrough" + hash_args() - - for message in stream_in: - assert_feed_unframe_protocol(message) - out_value = message - assert_transform_protocol(out_value) - yield (namestring, out_value) + def update(self, event): + assert isinstance(event, ndict),"Bad event in Passthrough: %s" % event + assert event.has_key('sid'), "No sid in Passthrough: %s" % event + assert event.has_key('dt'), "No dt in Passthorughz: %s" % event + return event def FunctionalTransformGen(stream_in, fun, *args, **kwargs): """ @@ -43,6 +41,10 @@ def FunctionalTransformGen(stream_in, fun, *args, **kwargs): """ # TODO: Distinguish between functions and classes in hash_args. + # As implemented we will get assertion errors if a function and + # stateful class have the same name, which may or may not be + # what we want. + namestring = fun.__name__ + hash_args(*args, **kwargs) for message in stream_in: @@ -70,13 +72,6 @@ def StatefulTransformGen(stream_in, tnfm_class, *args, **kwargs): assert_transform_protocol(out_value) yield (namestring, out_value) -def MovingAverageTransformGen(stream_in, days, fields): - """ - Generator that uses the MovingAverage state class to calculate - a moving average for all stocks over a specified number of days. - """ - return StatefulTransformGen(stream_in, MovingAverage, timedelta(days=days), fields) - class MovingAverage(object): """ Class that maintains a dictionary from sids to EventWindows @@ -91,7 +86,7 @@ class MovingAverage(object): # No way to pass arguments to the defaultdict factory, so we # need to define a method to generate the correct EventWindows. self.sid_windows = defaultdict(self.create_window) - + def create_window(self): """Factory method for self.sid_windows.""" return EventWindow(self.delta, self.fields) @@ -104,13 +99,18 @@ class MovingAverage(object): assert isinstance(event, ndict),"Bad event in MovingAverage: %s" % event assert event.has_key('sid'), "No sid in MovingAverage: %s" % event + assert event.has_key('dt'), "No dt in MovingAverage: %s" % event + output = ndict({'sid': event.sid, 'dt': event.dt}) # This will create a new EventWindow if this is the first # message for this sid. window = self.sid_windows[event.sid] window.update(event) + averages = window.get_averages() - return window.get_averages() + # Return the calculated averages along with + output.merge(averages) + return output class EventWindow(object): """ @@ -144,7 +144,7 @@ class EventWindow(object): # newest oldest # | | # V V - + while (self.ticks[-1].dt - self.ticks[0].dt) >= self.delta: # popleft removes and returns ticks[0] popped = self.ticks.popleft() @@ -168,6 +168,7 @@ class EventWindow(object): Return an ndict of all our tracked averages. """ out = ndict() + # out.ticks = len(self.ticks) for field in self.fields: out[field] = self.average(field) return out @@ -186,26 +187,26 @@ class EventWindow(object): assert isinstance(event[field], Number), \ "Got %s for %s in EventWindow" % (event[field], field) -if __name__ == "__main__": +# if __name__ == "__main__": - def make_event(**kwargs): - e = ndict() - for key, value in kwargs.iteritems(): - e[key] = value - return e +# def make_event(**kwargs): +# e = ndict() +# for key, value in kwargs.iteritems(): +# e[key] = value +# return e - dates = date_gen(delta = timedelta(hours = 12)) - events = ( - make_event( - sid = 'foo', price = random.random(), - dt = date, - type = zp.DATASOURCE_TYPE.TRADE, - source_id = 'ds', - vol = i - ) - for date, i in izip(dates, xrange(100)) - ) +# dates = date_gen(delta = timedelta(hours = 12)) +# events = ( +# make_event( +# sid = 'foo', price = random.random(), +# dt = date, +# type = zp.DATASOURCE_TYPE.TRADE, +# source_id = 'ds', +# vol = i +# ) +# for date, i in izip(dates, xrange(100)) +# ) - gen = MovingAverageTransformGen(events, 1, ['price', 'vol']) +# gen = MovingAverageTransformGen(events, 1, ['price', 'vol']) diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index efe81a14..c10cffdf 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -16,16 +16,28 @@ def mock_raw_event(sid, dt): } return event -def date_gen(start = datetime(2012, 6, 6, 0), delta = timedelta(minutes = 1), n = 100): - return (start + i * delta for i in xrange(n)) - +def mock_done(source_id): + return ndict({'dt': "DONE", "source_id" : source_id, 'type' : 0}) + def alternate(g1, g2): + """Specialized version of roundrobin for just 2 generators.""" for e1, e2 in izip_longest(g1, g2): if e1 != None: yield e1 if e2 != None: yield e2 +def roundrobin(*args): + """ + Takes N generators, pulling one element off each until all inputs + are empty. + """ + for elem_tuple in izip_longest(*args): + for value in elem_tuple: + if value != None: + yield value + + def hash_args(*args, **kwargs): """Define a unique string for any set of representable args.""" arg_string = '_'.join([str(arg) for arg in args]) @@ -73,14 +85,19 @@ def assert_feed_protocol(event): assert event.type in DATASOURCE_TYPE assert event.has_key('dt') - def assert_feed_unframe_protocol(event): """Same as above.""" assert isinstance(event, ndict) + assert isinstance(event.source_id, basestring) assert event.type in DATASOURCE_TYPE assert event.has_key('dt') - def assert_transform_protocol(event): - pass + """Transforms should return an ndict to be merged by MergeGen.""" + assert isinstance(event, ndict) + +def assert_merge_protocol(tnfm_ids, message): + """Merge should output an ndict with a field for each id in its transform set.""" + assert isinstance(message, ndict) + assert set(tnfm_ids) == set(message.keys()) diff --git a/zipline/gens/zmq_gens.py b/zipline/gens/zmq_gens.py new file mode 100644 index 00000000..524852a7 --- /dev/null +++ b/zipline/gens/zmq_gens.py @@ -0,0 +1,16 @@ +import zmq + +import zipline.protocol as zp + +def gen_from_zmq(poller, unframe): + """ + A generator that takes an initialized zmq poller and yields + messages from the poller until it gets a zp.CONTROL_PROTOCOL.DONE. + """ + while True: + message = poller.recv() + if message = zp.CONTROL_PROTOCOL.DONE: + yield "DONE" + break + else: + yield unframe(message) diff --git a/zipline/utils/factory.py b/zipline/utils/factory.py index 23c2eb3e..db440891 100644 --- a/zipline/utils/factory.py +++ b/zipline/utils/factory.py @@ -69,9 +69,9 @@ def create_trading_environment(year=2006): return trading_environment -def create_trade(sid, price, amount, datetime): +def create_trade(sid, price, amount, datetime, source_id = "test_factory"): row = zp.ndict({ - 'source_id' : "test_factory", + 'source_id' : source_id, 'type' : zp.DATASOURCE_TYPE.TRADE, 'sid' : sid, 'dt' : datetime,