From 9f7293e2d200bfa4d555e8995d340c6d64d4902f Mon Sep 17 00:00:00 2001 From: scottsanderson Date: Wed, 1 Aug 2012 17:19:08 -0400 Subject: [PATCH] pipeline through merge --- zipline/gens/composites.py | 46 ++++++++++------------ zipline/gens/examples.py | 81 +++++++++++++++++++++++--------------- zipline/gens/merge.py | 39 +++++++++++------- zipline/gens/tradegens.py | 6 +-- zipline/gens/transform.py | 49 ++++++++++++++++------- zipline/gens/utils.py | 19 ++++++--- 6 files changed, 145 insertions(+), 95 deletions(-) diff --git a/zipline/gens/composites.py b/zipline/gens/composites.py index 66697fa7..234db714 100644 --- a/zipline/gens/composites.py +++ b/zipline/gens/composites.py @@ -8,8 +8,8 @@ from zipline.gens.sort import date_sort from zipline.gens.merge import merge from zipline.gens.transform import stateful_transform -SortBundle = namedtuple("SortBundle", ['source', 'args', 'kwargs']) -MergeBundle = namedtuple("MergeBundle", ['stream', 'tnfm', 'args', 'kwargs']) +SourceBundle = namedtuple("SourceBundle", ['source', 'args', 'kwargs']) +TransformBundle = namedtuple("TransformBundle", ['tnfm', 'args', 'kwargs']) def date_sorted_sources(bundles): """ @@ -18,19 +18,19 @@ def date_sorted_sources(bundles): """ assert isinstance(bundles, (list, tuple)) for bundle in bundles: - assert isinstance(bundle, SortBundle) + assert isinstance(bundle, SourceBundle) # Calculate namestring hashes to pass to date_sort. names = [bundle.source.__name__ + hash_args(*bundle.args, **bundle.kwargs) for bundle in bundles] # Pass each source its arguments. - initialized = [bundle.source(*bundle.args, **bundle.kwargs) + source_gens = [bundle.source(*bundle.args, **bundle.kwargs) for bundle in bundles] # Convert the list of generators into a flat stream by pulling # one element at a time from each. - stream_in = roundrobin(initialized, names) + stream_in = roundrobin(source_gens, names) # Guarantee the flat stream will be sorted by date, using source_id as # tie-breaker, which is fully deterministic (given deterministic string @@ -38,7 +38,7 @@ def date_sorted_sources(bundles): return date_sort(stream_in, names) -def merged_transforms(sorted_stream, tnfms, tnfm_args, tnfm_kwargs): +def merged_transforms(sorted_stream, bundles): """ A generator that takes the expected output of a date_sort, pipes it through a given set of transforms, and runs the results throught a @@ -48,36 +48,30 @@ def merged_transforms(sorted_stream, tnfms, tnfm_args, tnfm_kwargs): tnfm_kwargs should be a list of dictionaries representing keyword arguments to each transform. """ - - # We should have as many sets of args as we have transforms. - assert len(tnfms) == len(tnfm_args) == len(tnfm_kwargs) + # Generate expected hashes for each transform + namestrings = [bundle.tnfm.__name__ + hash_args(*bundle.args, **bundle.kwargs) + for bundle in bundles] # Create a copy of the stream for each transform. - split = tee(sorted_stream, len(tnfms)) + split = tee(sorted_stream, len(bundles)) + # Package a stream copy with each bundle + tnfms_with_streams = zip(split, bundles) - # Package each transform with a stream copy and set of args. Use a list - # so that we can re-use this for calculating hashes. - bundle_gen = starmap(MergeBundle, zip(split, tnfms, tnfm_args, tnfm_kwargs)) - - bundles = tuple(bundle_gen) - # list comprehension to create transform generators from - # bundles + # Convert the copies into transform streams. tnfm_gens = [ stateful_transform( - bundle.stream, + stream_copy, bundle.tnfm, *bundle.args, **bundle.kwargs ) - for bundle in bundles] - - # Generate expected hashes for each transform - hashes = [bundle.tnfm.__name__ + hash_args(*bundle.args, **bundle.kwargs) - for bundle in bundles] + for stream_copy, bundle in tnfms_with_streams + ] # Roundrobin the outputs of our transforms to create a single flat stream. - to_merge = roundrobin(*tnfm_gens) + to_merge = roundrobin(tnfm_gens, namestrings) # Pipe the stream into merge. - merged = merge(to_merge, hashes) - return merged_transforms + merged = merge(to_merge, namestrings) + # Return the merged events. + return merged diff --git a/zipline/gens/examples.py b/zipline/gens/examples.py index d9051b10..fb3d8827 100644 --- a/zipline/gens/examples.py +++ b/zipline/gens/examples.py @@ -1,38 +1,57 @@ -from zipline.gens.composites import +from datetime import datetime, timedelta + +from zipline.utils.factory import create_trading_environment +from zipline.test_algorithms import TestAlgorithm + +from zipline.gens.composites import SourceBundle, TransformBundle, date_sorted_sources, merged_transforms +from zipline.gens.tradegens import SpecificEquityTrades +from zipline.gens.transform import MovingAverage, Passthrough if __name__ == "__main__": filter = [1,2,3,4] - #Set up source a. One hour between events. + #Set up source a. One minute between events. args_a = tuple() - kwargs_a = {'sids' : [1,2,3,4], - 'start' : datetime(2012,6,6,0), - 'delta' : timedelta(minutes = ), - 'filter' : filter - } - #Set up source b. One day between events. - args_b = tuple() - kwargs_b = {'sids' : [1,2,3,4], - 'start' : datetime(2012,6,6,0), - 'delta' : timedelta(days = 1), - 'filter' : filter - } - #Set up source c. One minute between events. - args_c = tuple() - kwargs_c = {'sids' : [1,2,3,4], - 'start' : datetime(2012,6,6,0), - 'delta' : timedelta(minutes = 1), - 'filter' : filter - } + kwargs_a = { + 'sids' : [1,2], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(minutes = 1), + 'filter' : filter + } + bundle_a = SourceBundle(SpecificEquityTrades, args_a, kwargs_a) - sources = (SpecificEquityTrades,) * 4 - source_args = (args_a, args_b, args_c, args_d) - source_kwargs = (kwargs_a, kwargs_b, kwargs_c, kwargs_d) - - # Generate our expected source_ids. - zip_args = zip(source_args, source_kwargs) - expected_ids = ["SpecificEquityTrades" + hash_args(*args, **kwargs) - for args, kwargs in zip_args] - + #Set up source b. Two minutes between events. + args_b = tuple() + kwargs_b = { + 'sids' : [2,3], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(minutes = 2), + 'filter' : filter + } + bundle_b = SourceBundle(SpecificEquityTrades, args_b, kwargs_b) + + #Set up source c. Three minutes between events. + args_c = tuple() + kwargs_c = { + 'sids' : [3,4], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(minutes = 3), + 'filter' : filter + } + bundle_c = SourceBundle(SpecificEquityTrades, args_c, kwargs_c) + + source_bundles = (bundle_a, bundle_b, bundle_c) # Pipe our sources into sort. - sort_out = date_sorted_sources(sources, source_args, source_kwargs) + sort_out = date_sorted_sources(source_bundles) + + passthrough = TransformBundle(Passthrough, (), {}) + mavg_price = TransformBundle(MovingAverage, (timedelta(minutes = 20), ['price', 'volume']), {}) + tnfm_bundles = (passthrough, mavg_price) + + merge_out = merged_transforms(sort_out, tnfm_bundles) + + for message in merge_out: + print "Event: \n", message.event + print "Transforms: \n", message.tnfms + + diff --git a/zipline/gens/merge.py b/zipline/gens/merge.py index 4778ed5b..0e8fab93 100644 --- a/zipline/gens/merge.py +++ b/zipline/gens/merge.py @@ -7,7 +7,7 @@ from collections import deque from zipline import ndict from zipline.gens.utils import hash_args, \ assert_merge_protocol - +from itertools import repeat def merge(stream_in, tnfm_ids): """ @@ -17,7 +17,7 @@ def merge(stream_in, tnfm_ids): and merge them together into an event. We raise an error if we do not receive the same number of events from all sources. """ - + assert isinstance(tnfm_ids, list) # Set up an internal queue for each expected source. @@ -28,22 +28,22 @@ def merge(stream_in, tnfm_ids): # Process incoming streams. for message in stream_in: - assert isinstance(message, tuple), \ - "Bad message in merge: %s" %message - assert len(message) == 2 - id, value = message + assert isinstance(message, ndict) + assert message.has_key('tnfm_id') + assert message.has_key('tnfm_value') + assert message.has_key('dt') + + id = message.tnfm_id assert id in tnfm_ids, \ "Message from unexpected tnfm: %s, %s" % (id, tnfm_ids) - assert isinstance(value, ndict), "Bad message in merge: %s" %message - - tnfms[id].append(value) + + tnfms[id].append(message) # Only pop messages when we have a pending message from # all datasources. Stop if all sources have signalled done. while ready(tnfms) and not done(tnfms): message = merge_one(tnfms) - assert_merge_protocol(tnfm_ids, message) yield message # We should have only a done message left in each queue. @@ -53,11 +53,22 @@ def merge(stream_in, tnfm_ids): "Bad last message in merge on exit: %s" % queue def merge_one(sources): - output = ndict() + dict_primer = zip(sources.keys(), repeat(None)) + transforms = ndict(dict_primer) + event_fields = ndict() + for key, queue in sources.iteritems(): - new_xform = ndict({key: queue.popleft()}) - output.merge(new_xform) - return output + + # Add transform value to the transforms dict. + message = queue.popleft() + transforms[message.tnfm_id] = message.tnfm_value + del message['tnfm_id'] + del message['tnfm_value'] + + # Merge any remaining fields into the event dict. + event_fields.merge(message) + + return ndict({'event' : event_fields, 'tnfms' : transforms}) #TODO: This is replicated in sort. Probably should be one source file. diff --git a/zipline/gens/tradegens.py b/zipline/gens/tradegens.py index a24cbe58..e3b88a0e 100644 --- a/zipline/gens/tradegens.py +++ b/zipline/gens/tradegens.py @@ -106,10 +106,8 @@ def SpecificEquityTrades(*args, **config): else: filtered = unfiltered - # Add a done message to the end of the stream. For a live - # datasource this would be handled by the containing Component. - out = chain(filtered, [mock_done(namestring)]) - return out + # Return the filtered event stream. + return filtered def RandomEquityTrades(*args, **config): # We shouldn't get any positional args. diff --git a/zipline/gens/transform.py b/zipline/gens/transform.py index 64c817ca..ece2d383 100644 --- a/zipline/gens/transform.py +++ b/zipline/gens/transform.py @@ -3,6 +3,7 @@ Generator versions of transforms. """ import types +from copy import deepcopy from datetime import datetime from collections import deque, defaultdict from numbers import Number @@ -12,6 +13,7 @@ from zipline.gens.utils import assert_sort_unframe_protocol, \ assert_transform_protocol, hash_args class Passthrough(object): + FORWARDER = True """ Trivial class for forwarding events. """ @@ -19,10 +21,7 @@ class Passthrough(object): pass def update(self, event): - assert isinstance(event, ndict),"Bad event in Passthrough: %s" % event - assert event.has_key('sid'), "No sid in Passthrough: %s" % event - assert event.has_key('dt'), "No dt in Passthorughz: %s" % event - return event + pass def functional_transform(stream_in, func, *args, **kwargs): """ @@ -44,9 +43,13 @@ def stateful_transform(stream_in, tnfm_class, *args, **kwargs): """ Generic transform generator that takes each message from an in-stream and passes it to a state class. For each call to update, the state - class must produce a message to be fed downstream. + class must produce a message to be fed downstream. Any transform class + with the FORWARDER class variable set to true will forward all fields + in the original message. Otherwise only dt, tnfm_id, and tnfm_value + are forwarded. """ - + forward_all_fields = tnfm_class.__dict__.get('FORWARDER', False) + assert isinstance(tnfm_class, (types.ObjectType, types.ClassType)), \ "Stateful transform requires a class." assert tnfm_class.__dict__.has_key('update'), \ @@ -58,11 +61,31 @@ def stateful_transform(stream_in, tnfm_class, *args, **kwargs): # Generate the string associated with this generator's output. namestring = tnfm_class.__name__ + hash_args(*args, **kwargs) + # IMPORTANT: Messages may contain pointers that are shared with + # other streams, so we only manipulate copies. for message in stream_in: + assert_sort_unframe_protocol(message) - out_value = state.update(message) - assert_transform_protocol(out_value) - yield (namestring, out_value) + message_copy = deepcopy(message) + + # Same shared pointer issue here as above. + tnfm_value = state.update(deepcopy(message_copy)) + + # If we want to keep all original values, just append tnfm_id + # and tnfm_value. + if forward_all_fields: + out_message = message_copy + out_message.tnfm_id = namestring + out_message.tnfm_value = tnfm_value + yield out_message + + # Otherwise send tnfm_id, tnfm_value, and the message date. + else: + out_message = ndict() + out_message.tnfm_id = namestring + out_message.tnfm_value = tnfm_value + out_message.dt = message_copy.dt + yield out_message class MovingAverage(object): """ @@ -70,6 +93,7 @@ class MovingAverage(object): Upon receipt of each message we update the corresponding window and return the calculated average. """ + FORWARDER = False def __init__(self, delta, fields): self.delta = delta @@ -93,16 +117,11 @@ class MovingAverage(object): assert event.has_key('sid'), "No sid in MovingAverage: %s" % event assert event.has_key('dt'), "No dt in MovingAverage: %s" % event - output = ndict({'sid': event.sid, 'dt': event.dt}) # This will create a new EventWindow if this is the first # message for this sid. window = self.sid_windows[event.sid] window.update(event) - averages = window.get_averages() - - # Return the calculated averages along with - output.merge(averages) - return output + return window.get_averages() class EventWindow(object): """ diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index 209c98b0..d51452bf 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -1,6 +1,7 @@ import pytz import numbers +from collections import OrderedDict from hashlib import md5 from datetime import datetime from itertools import izip_longest @@ -16,8 +17,16 @@ def mock_raw_event(sid, dt): } return event -def mock_done(source_id): - return ndict({'dt': "DONE", "source_id" : source_id, 'type' : 0}) +def mock_done(id): + return ndict({ + 'dt' : "DONE", + "source_id" : id, + 'tnfm_id' : id, + 'tnfm_value': None, + 'type' : 0 + }) + +done_message = mock_done def alternate(g1, g2): """Specialized version of roundrobin for just 2 generators.""" @@ -36,14 +45,14 @@ def roundrobin(sources, namestrings): mapping = OrderedDict(zip(namestrings, sources)) # While our generators have not been exhausted, pull elements - while mapping != []: - for namestring, source in mapping: + while mapping.keys() != []: + for namestring, source in mapping.iteritems(): try: message = source.next() yield message except StopIteration: yield done_message(namestring) - del mapping(namestring) + del mapping[namestring]