diff --git a/zipline/gens/composites.py b/zipline/gens/composites.py index 1f68d58e..ba0538b6 100644 --- a/zipline/gens/composites.py +++ b/zipline/gens/composites.py @@ -1,18 +1,62 @@ -from zipline.gens.utils import roundrobin +from itertools import tee + +from zipline.gens.tradegens import SpecificEquityTrades +from zipline.gens.utils import roundrobin, hash_args from zipline.gens.feed import FeedGen +from zipline.gens.merge import MergeGen +def PreTransformLayer(sources, source_args, source_kwargs): + """ + Takes a list of generator functions, a list of tuples of positional arguments, + and a list of dictionaries of keyword arguments. Packages up all arguments + and passes them into a FeedGen. + """ + assert len(sources) == len(source_args) == len(source_kwargs) + # Package up sources and arguments. + arg_bundles = zip(sources, source_args, source_kwargs) + # Calculate namestring hashes to pass to FeedGen. + namestrings = [source.__name__ + hash_args(*args, **kwargs) + for source, args, kwargs in arg_bundles] + # Pass each source its arguments. + initialized = tuple(source(*args, **kwargs) + for source, args, kwargs in arg_bundles) -def PreTransformLayer(sources, source_ids): - """ - A generator that takes a tuple of sources and a list ids, piping - their output into a feed_gen. - """ - stream_in = roundrobin(*sources) + stream_in = roundrobin(*initialized) return FeedGen(stream_in, source_ids) -def TransformLayer(feed_stream, tnfms): - """ """ - pass +def TransformLayer(feed_stream, tnfms, tnfm_args, tnfm_kwargs): + """ + A generator that takes the expected output of a FeedGen, pipes it + through a given set of transforms, and runs the results through a + MergeGen to output a unified stream. tnfms should be a list of + pointers to generator functions. tnfm_args should be a list of + tuples, representing the arguments to be passed to each transform. + tnfm_kwargs should be a list of dictionaries representing keyword + arguments to each transform. 
+ """ + # We should have as many sets of args as we have transforms. + assert len(tnfms) == len(tnfm_args) == len(tnfm_kwargs) + # Create a copy of the stream for each transform. + split = tee(feed_stream, len(tnfms)) + + # Package each stream copy with a transform and set of args. Use a list + # so that we can re-use this for calculating hashes. + bundles = zip(split, iter(tnfms), iter(tnfm_args), iter(tnfm_kwargs)) + + # Convert the argument bundles into a tuple of transform objects. + transformed = tuple((tnfm(stream, *args, **kwargs) + for stream, tnfm, args, kwargs in iter(bundles))) + + # Roundrobin the outputs of our transforms to create a single flat stream. + to_merge = roundrobin(*transformed) + + merged = MergeGen() + + +if __name__ == "__main__": + + source = SpecificEquityTrades() + diff --git a/zipline/gens/feed.py b/zipline/gens/feed.py index 0528d7ee..284023a0 100644 --- a/zipline/gens/feed.py +++ b/zipline/gens/feed.py @@ -53,7 +53,6 @@ def FeedGen(stream_in, source_ids): assert_feed_protocol(message) yield message - import nose.tools; nose.tools.set_trace() # We should have only a done message left in each queue. for queue in sources.itervalues(): assert len(queue) == 1, "Bad queue in FeedGen on exit: %s" % queue diff --git a/zipline/gens/test_feed.py b/zipline/gens/test_feed.py index 12cf4bc5..7c486e33 100644 --- a/zipline/gens/test_feed.py +++ b/zipline/gens/test_feed.py @@ -153,39 +153,41 @@ class FeedGenTestCase(TestCase): sequential = chain(iter(events_a), iter(events_b)) self.run_FeedGen(sequential, expected, source_ids) - + def test_full_feed_layer(self): filter = [1,2] + #Set up source a. + args_a = tuple() + kwargs_a = {'sids' : [1,2,3,4], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(minutes = 1), + 'filter' : filter + } + #Set up source b. + args_b = tuple() + kwargs_b = {'sids' : [1,2,3,5], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(minutes = 1), + 'filter' : filter + } + #Set up source c. 
+ args_c = tuple() + kwargs_c = {'sids' : [1,2,3,5], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(minutes = 1), + 'filter' : filter + } - source_a = SpecificEquityTrades(sids = [1,2,3,4], - start = datetime(2012,6,6,0), - delta = timedelta(minutes=1), - filter = filter - ) - id_a = "SpecificEquityTradesd175237b28d2f52df208c97cf4af896e" - - # Change the internal sid list to give us a different hash. - source_b = SpecificEquityTrades(sids = [1,2,3,5], - start = datetime(2012,6,6,0), - delta = timedelta(minutes=1), - filter = filter - ) + sources = tuple(SpecificEquityTrades) * 3 + source_args = (args_a, args_b, args_c) + source_kwargs = (kwargs_a, kwargs_b, kwargs_c) - id_b = 'SpecificEquityTrades2bf2c2d6d01d4dbfc0b2818438ea8151' - - # Change the internal sid list to give us a different hash. - source_c = SpecificEquityTrades(sids = [1,2,3,6], - start = datetime(2012,6,6,0), - delta = timedelta(minutes=1), - filter = filter - ) - id_c = 'SpecificEquityTrades16f7437db2d14e5373ef20025f49a3fe' - - sources = (source_a, source_b, source_c) - source_ids = [id_a, id_b, id_c] + feed_out = PreTransformLayer(sources, source_args, source_kwargs) + to_list = list(feed_out) + copy = to_list[:] + expected = sorted(copy, compare_by_dt_source_id) - feed_out = PreTransformLayer(sources, source_ids) - l = list(feed_out) + assert to_list == expected def mock_data_unframe(source_id, dt, type): event = ndict() @@ -210,7 +212,3 @@ def compare_by_dt_source_id(x,y): else: return 0 - - - - diff --git a/zipline/gens/transform.py b/zipline/gens/transform.py index bdd85ee9..666fdd46 100644 --- a/zipline/gens/transform.py +++ b/zipline/gens/transform.py @@ -14,8 +14,9 @@ from numbers import Number from itertools import izip from zipline import ndict -from zipline.gens.utils import hash_args, date_gen -from zipline.gens.utils import assert_feed_unframe_protocol, assert_transform_protocol +from zipline.gens.tradegens import date_gen +from zipline.gens.utils import 
assert_feed_unframe_protocol, \ + assert_transform_protocol, hash_args import zipline.protocol as zp @@ -43,6 +44,10 @@ def FunctionalTransformGen(stream_in, fun, *args, **kwargs): """ # TODO: Distinguish between functions and classes in hash_args. + # As implemented we will get assertion errors if a function and + # stateful class have the same name, which may or may not be + # what we want. + namestring = fun.__name__ + hash_args(*args, **kwargs) for message in stream_in: @@ -75,7 +80,10 @@ def MovingAverageTransformGen(stream_in, days, fields): Generator that uses the MovingAverage state class to calculate a moving average for all stocks over a specified number of days. """ - return StatefulTransformGen(stream_in, MovingAverage, timedelta(days=days), fields) + return StatefulTransformGen(stream_in, + MovingAverage, + timedelta(days=days), + fields) class MovingAverage(object): """ @@ -91,7 +99,7 @@ class MovingAverage(object): # No way to pass arguments to the defaultdict factory, so we # need to define a method to generate the correct EventWindows. self.sid_windows = defaultdict(self.create_window) - + def create_window(self): """Factory method for self.sid_windows.""" return EventWindow(self.delta, self.fields) @@ -105,12 +113,16 @@ class MovingAverage(object): assert isinstance(event, ndict),"Bad event in MovingAverage: %s" % event assert event.has_key('sid'), "No sid in MovingAverage: %s" % event + output = ndict({'sid': event.sid}) # This will create a new EventWindow if this is the first # message for this sid. 
window = self.sid_windows[event.sid] window.update(event) + averages = window.get_averages() - return window.get_averages() + # Return the calculated averages along with the event's sid. + output.merge(averages) + return output class EventWindow(object): """ @@ -144,7 +156,7 @@ class EventWindow(object): # newest oldest # | | # V V - + while (self.ticks[-1].dt - self.ticks[0].dt) >= self.delta: # popleft removes and returns ticks[0] popped = self.ticks.popleft() @@ -168,6 +180,7 @@ class EventWindow(object): Return an ndict of all our tracked averages. """ out = ndict() + # out.ticks = len(self.ticks) for field in self.fields: out[field] = self.average(field) return out @@ -186,26 +199,26 @@ class EventWindow(object): assert isinstance(event[field], Number), \ "Got %s for %s in EventWindow" % (event[field], field) -if __name__ == "__main__": +# if __name__ == "__main__": - def make_event(**kwargs): - e = ndict() - for key, value in kwargs.iteritems(): - e[key] = value - return e +# def make_event(**kwargs): +# e = ndict() +# for key, value in kwargs.iteritems(): +# e[key] = value +# return e - dates = date_gen(delta = timedelta(hours = 12)) - events = ( - make_event( - sid = 'foo', price = random.random(), - dt = date, - type = zp.DATASOURCE_TYPE.TRADE, - source_id = 'ds', - vol = i - ) - for date, i in izip(dates, xrange(100)) - ) +# dates = date_gen(delta = timedelta(hours = 12)) +# events = ( +# make_event( +# sid = 'foo', price = random.random(), +# dt = date, +# type = zp.DATASOURCE_TYPE.TRADE, +# source_id = 'ds', +# vol = i +# ) +# for date, i in izip(dates, xrange(100)) +# ) - gen = MovingAverageTransformGen(events, 1, ['price', 'vol']) +# gen = MovingAverageTransformGen(events, 1, ['price', 'vol']) diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index 8ed6ea9d..6150757b 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -85,14 +85,14 @@ def assert_feed_protocol(event): assert event.type in DATASOURCE_TYPE assert event.has_key('dt') - def 
assert_feed_unframe_protocol(event): """Same as above.""" assert isinstance(event, ndict) + assert isinstance(event.source_id, basestring) assert event.type in DATASOURCE_TYPE assert event.has_key('dt') - def assert_transform_protocol(event): - pass + """Transforms should return an ndict to be merged by MergeGen.""" + assert isinstance(event, ndict)