pipeline through merge

This commit is contained in:
scottsanderson
2012-08-01 17:19:08 -04:00
parent 6de01a1c6e
commit 9f7293e2d2
6 changed files with 145 additions and 95 deletions
+20 -26
View File
@@ -8,8 +8,8 @@ from zipline.gens.sort import date_sort
from zipline.gens.merge import merge
from zipline.gens.transform import stateful_transform
SortBundle = namedtuple("SortBundle", ['source', 'args', 'kwargs'])
MergeBundle = namedtuple("MergeBundle", ['stream', 'tnfm', 'args', 'kwargs'])
SourceBundle = namedtuple("SourceBundle", ['source', 'args', 'kwargs'])
TransformBundle = namedtuple("TransformBundle", ['tnfm', 'args', 'kwargs'])
def date_sorted_sources(bundles):
"""
@@ -18,19 +18,19 @@ def date_sorted_sources(bundles):
"""
assert isinstance(bundles, (list, tuple))
for bundle in bundles:
assert isinstance(bundle, SortBundle)
assert isinstance(bundle, SourceBundle)
# Calculate namestring hashes to pass to date_sort.
names = [bundle.source.__name__ + hash_args(*bundle.args, **bundle.kwargs)
for bundle in bundles]
# Pass each source its arguments.
initialized = [bundle.source(*bundle.args, **bundle.kwargs)
source_gens = [bundle.source(*bundle.args, **bundle.kwargs)
for bundle in bundles]
# Convert the list of generators into a flat stream by pulling
# one element at a time from each.
stream_in = roundrobin(initialized, names)
stream_in = roundrobin(source_gens, names)
# Guarantee the flat stream will be sorted by date, using source_id as
# tie-breaker, which is fully deterministic (given deterministic string
@@ -38,7 +38,7 @@ def date_sorted_sources(bundles):
return date_sort(stream_in, names)
def merged_transforms(sorted_stream, tnfms, tnfm_args, tnfm_kwargs):
def merged_transforms(sorted_stream, bundles):
"""
A generator that takes the expected output of a date_sort, pipes it
through a given set of transforms, and runs the results throught a
@@ -48,36 +48,30 @@ def merged_transforms(sorted_stream, tnfms, tnfm_args, tnfm_kwargs):
tnfm_kwargs should be a list of dictionaries representing keyword
arguments to each transform.
"""
# We should have as many sets of args as we have transforms.
assert len(tnfms) == len(tnfm_args) == len(tnfm_kwargs)
# Generate expected hashes for each transform
namestrings = [bundle.tnfm.__name__ + hash_args(*bundle.args, **bundle.kwargs)
for bundle in bundles]
# Create a copy of the stream for each transform.
split = tee(sorted_stream, len(tnfms))
split = tee(sorted_stream, len(bundles))
# Package a stream copy with each bundle
tnfms_with_streams = zip(split, bundles)
# Package each transform with a stream copy and set of args. Use a list
# so that we can re-use this for calculating hashes.
bundle_gen = starmap(MergeBundle, zip(split, tnfms, tnfm_args, tnfm_kwargs))
bundles = tuple(bundle_gen)
# list comprehension to create transform generators from
# bundles
# Convert the copies into transform streams.
tnfm_gens = [
stateful_transform(
bundle.stream,
stream_copy,
bundle.tnfm,
*bundle.args,
**bundle.kwargs
)
for bundle in bundles]
# Generate expected hashes for each transform
hashes = [bundle.tnfm.__name__ + hash_args(*bundle.args, **bundle.kwargs)
for bundle in bundles]
for stream_copy, bundle in tnfms_with_streams
]
# Roundrobin the outputs of our transforms to create a single flat stream.
to_merge = roundrobin(*tnfm_gens)
to_merge = roundrobin(tnfm_gens, namestrings)
# Pipe the stream into merge.
merged = merge(to_merge, hashes)
return merged_transforms
merged = merge(to_merge, namestrings)
# Return the merged events.
return merged
+50 -31
View File
@@ -1,38 +1,57 @@
from zipline.gens.composites import
from datetime import datetime, timedelta
from zipline.utils.factory import create_trading_environment
from zipline.test_algorithms import TestAlgorithm
from zipline.gens.composites import SourceBundle, TransformBundle, date_sorted_sources, merged_transforms
from zipline.gens.tradegens import SpecificEquityTrades
from zipline.gens.transform import MovingAverage, Passthrough
if __name__ == "__main__":
filter = [1,2,3,4]
#Set up source a. One hour between events.
#Set up source a. One minute between events.
args_a = tuple()
kwargs_a = {'sids' : [1,2,3,4],
'start' : datetime(2012,6,6,0),
'delta' : timedelta(minutes = ),
'filter' : filter
}
#Set up source b. One day between events.
args_b = tuple()
kwargs_b = {'sids' : [1,2,3,4],
'start' : datetime(2012,6,6,0),
'delta' : timedelta(days = 1),
'filter' : filter
}
#Set up source c. One minute between events.
args_c = tuple()
kwargs_c = {'sids' : [1,2,3,4],
'start' : datetime(2012,6,6,0),
'delta' : timedelta(minutes = 1),
'filter' : filter
}
kwargs_a = {
'sids' : [1,2],
'start' : datetime(2012,6,6,0),
'delta' : timedelta(minutes = 1),
'filter' : filter
}
bundle_a = SourceBundle(SpecificEquityTrades, args_a, kwargs_a)
sources = (SpecificEquityTrades,) * 4
source_args = (args_a, args_b, args_c, args_d)
source_kwargs = (kwargs_a, kwargs_b, kwargs_c, kwargs_d)
# Generate our expected source_ids.
zip_args = zip(source_args, source_kwargs)
expected_ids = ["SpecificEquityTrades" + hash_args(*args, **kwargs)
for args, kwargs in zip_args]
#Set up source b. Two minutes between events.
args_b = tuple()
kwargs_b = {
'sids' : [2,3],
'start' : datetime(2012,6,6,0),
'delta' : timedelta(minutes = 2),
'filter' : filter
}
bundle_b = SourceBundle(SpecificEquityTrades, args_b, kwargs_b)
#Set up source c. Three minutes between events.
args_c = tuple()
kwargs_c = {
'sids' : [3,4],
'start' : datetime(2012,6,6,0),
'delta' : timedelta(minutes = 3),
'filter' : filter
}
bundle_c = SourceBundle(SpecificEquityTrades, args_c, kwargs_c)
source_bundles = (bundle_a, bundle_b, bundle_c)
# Pipe our sources into sort.
sort_out = date_sorted_sources(sources, source_args, source_kwargs)
sort_out = date_sorted_sources(source_bundles)
passthrough = TransformBundle(Passthrough, (), {})
mavg_price = TransformBundle(MovingAverage, (timedelta(minutes = 20), ['price', 'volume']), {})
tnfm_bundles = (passthrough, mavg_price)
merge_out = merged_transforms(sort_out, tnfm_bundles)
for message in merge_out:
print "Event: \n", message.event
print "Transforms: \n", message.tnfms
+25 -14
View File
@@ -7,7 +7,7 @@ from collections import deque
from zipline import ndict
from zipline.gens.utils import hash_args, \
assert_merge_protocol
from itertools import repeat
def merge(stream_in, tnfm_ids):
"""
@@ -17,7 +17,7 @@ def merge(stream_in, tnfm_ids):
and merge them together into an event. We raise an error if we
do not receive the same number of events from all sources.
"""
assert isinstance(tnfm_ids, list)
# Set up an internal queue for each expected source.
@@ -28,22 +28,22 @@ def merge(stream_in, tnfm_ids):
# Process incoming streams.
for message in stream_in:
assert isinstance(message, tuple), \
"Bad message in merge: %s" %message
assert len(message) == 2
id, value = message
assert isinstance(message, ndict)
assert message.has_key('tnfm_id')
assert message.has_key('tnfm_value')
assert message.has_key('dt')
id = message.tnfm_id
assert id in tnfm_ids, \
"Message from unexpected tnfm: %s, %s" % (id, tnfm_ids)
assert isinstance(value, ndict), "Bad message in merge: %s" %message
tnfms[id].append(value)
tnfms[id].append(message)
# Only pop messages when we have a pending message from
# all datasources. Stop if all sources have signalled done.
while ready(tnfms) and not done(tnfms):
message = merge_one(tnfms)
assert_merge_protocol(tnfm_ids, message)
yield message
# We should have only a done message left in each queue.
@@ -53,11 +53,22 @@ def merge(stream_in, tnfm_ids):
"Bad last message in merge on exit: %s" % queue
def merge_one(sources):
output = ndict()
dict_primer = zip(sources.keys(), repeat(None))
transforms = ndict(dict_primer)
event_fields = ndict()
for key, queue in sources.iteritems():
new_xform = ndict({key: queue.popleft()})
output.merge(new_xform)
return output
# Add transform value to the transforms dict.
message = queue.popleft()
transforms[message.tnfm_id] = message.tnfm_value
del message['tnfm_id']
del message['tnfm_value']
# Merge any remaining fields into the event dict.
event_fields.merge(message)
return ndict({'event' : event_fields, 'tnfms' : transforms})
#TODO: This is replicated in sort. Probably should be one source file.
+2 -4
View File
@@ -106,10 +106,8 @@ def SpecificEquityTrades(*args, **config):
else:
filtered = unfiltered
# Add a done message to the end of the stream. For a live
# datasource this would be handled by the containing Component.
out = chain(filtered, [mock_done(namestring)])
return out
# Return the filtered event stream.
return filtered
def RandomEquityTrades(*args, **config):
# We shouldn't get any positional args.
+34 -15
View File
@@ -3,6 +3,7 @@ Generator versions of transforms.
"""
import types
from copy import deepcopy
from datetime import datetime
from collections import deque, defaultdict
from numbers import Number
@@ -12,6 +13,7 @@ from zipline.gens.utils import assert_sort_unframe_protocol, \
assert_transform_protocol, hash_args
class Passthrough(object):
FORWARDER = True
"""
Trivial class for forwarding events.
"""
@@ -19,10 +21,7 @@ class Passthrough(object):
pass
def update(self, event):
assert isinstance(event, ndict),"Bad event in Passthrough: %s" % event
assert event.has_key('sid'), "No sid in Passthrough: %s" % event
assert event.has_key('dt'), "No dt in Passthorughz: %s" % event
return event
pass
def functional_transform(stream_in, func, *args, **kwargs):
"""
@@ -44,9 +43,13 @@ def stateful_transform(stream_in, tnfm_class, *args, **kwargs):
"""
Generic transform generator that takes each message from an in-stream
and passes it to a state class. For each call to update, the state
class must produce a message to be fed downstream.
class must produce a message to be fed downstream. Any transform class
with the FORWARDER class variable set to true will forward all fields
in the original message. Otherwise only dt, tnfm_id, and tnfm_value
are forwarded.
"""
forward_all_fields = tnfm_class.__dict__.get('FORWARDER', False)
assert isinstance(tnfm_class, (types.ObjectType, types.ClassType)), \
"Stateful transform requires a class."
assert tnfm_class.__dict__.has_key('update'), \
@@ -58,11 +61,31 @@ def stateful_transform(stream_in, tnfm_class, *args, **kwargs):
# Generate the string associated with this generator's output.
namestring = tnfm_class.__name__ + hash_args(*args, **kwargs)
# IMPORTANT: Messages may contain pointers that are shared with
# other streams, so we only manipulate copies.
for message in stream_in:
assert_sort_unframe_protocol(message)
out_value = state.update(message)
assert_transform_protocol(out_value)
yield (namestring, out_value)
message_copy = deepcopy(message)
# Same shared pointer issue here as above.
tnfm_value = state.update(deepcopy(message_copy))
# If we want to keep all original values, just append tnfm_id
# and tnfm_value.
if forward_all_fields:
out_message = message_copy
out_message.tnfm_id = namestring
out_message.tnfm_value = tnfm_value
yield out_message
# Otherwise send tnfm_id, tnfm_value, and the message date.
else:
out_message = ndict()
out_message.tnfm_id = namestring
out_message.tnfm_value = tnfm_value
out_message.dt = message_copy.dt
yield out_message
class MovingAverage(object):
"""
@@ -70,6 +93,7 @@ class MovingAverage(object):
Upon receipt of each message we update the
corresponding window and return the calculated average.
"""
FORWARDER = False
def __init__(self, delta, fields):
self.delta = delta
@@ -93,16 +117,11 @@ class MovingAverage(object):
assert event.has_key('sid'), "No sid in MovingAverage: %s" % event
assert event.has_key('dt'), "No dt in MovingAverage: %s" % event
output = ndict({'sid': event.sid, 'dt': event.dt})
# This will create a new EventWindow if this is the first
# message for this sid.
window = self.sid_windows[event.sid]
window.update(event)
averages = window.get_averages()
# Return the calculated averages along with
output.merge(averages)
return output
return window.get_averages()
class EventWindow(object):
"""
+14 -5
View File
@@ -1,6 +1,7 @@
import pytz
import numbers
from collections import OrderedDict
from hashlib import md5
from datetime import datetime
from itertools import izip_longest
@@ -16,8 +17,16 @@ def mock_raw_event(sid, dt):
}
return event
def mock_done(source_id):
return ndict({'dt': "DONE", "source_id" : source_id, 'type' : 0})
def mock_done(id):
return ndict({
'dt' : "DONE",
"source_id" : id,
'tnfm_id' : id,
'tnfm_value': None,
'type' : 0
})
done_message = mock_done
def alternate(g1, g2):
"""Specialized version of roundrobin for just 2 generators."""
@@ -36,14 +45,14 @@ def roundrobin(sources, namestrings):
mapping = OrderedDict(zip(namestrings, sources))
# While our generators have not been exhausted, pull elements
while mapping != []:
for namestring, source in mapping:
while mapping.keys() != []:
for namestring, source in mapping.iteritems():
try:
message = source.next()
yield message
except StopIteration:
yield done_message(namestring)
del mapping(namestring)
del mapping[namestring]