updates for transforms

This commit is contained in:
scottsanderson
2012-07-29 19:56:10 -04:00
parent e048e8bc35
commit fe1740a3ce
5 changed files with 124 additions and 70 deletions
+54 -10
View File
@@ -1,18 +1,62 @@
from zipline.gens.utils import roundrobin
from itertools import tee
from zipline.gens.tradegens import SpecificEquityTrades
from zipline.gens.utils import roundrobin, hash_args
from zipline.gens.feed import FeedGen
from zipline.gens.merge import MergeGen
def PreTransformLayer(sources, source_args, source_kwargs):
"""
Takes a list of generator functions, a list of tuples of positional arguments,
and a list of dictionaries of keyword arguments. Packages up all arguments
and passes them into a FeedGen.
"""
assert len(sources) == len(source_args) == len(source_kwargs)
# Package up sources and arguments.
arg_bundles = zip(sources, source_args, source_kwargs)
# Calculate namestring hashes to pass to FeedGen.
namestrings = [source.__name__ + hash_args(*args, **kwargs)
for source, args, kwargs in arg_bundles]
# Pass each source its arguments.
initialized = tuple(source(*args, **kwargs)
for source, args, kwargs in arg_bundles)
def PreTransformLayer(sources, source_ids):
"""
A generator that takes a tuple of sources and a list ids, piping
their output into a feed_gen.
"""
stream_in = roundrobin(*sources)
stream_in = roundrobin(*initialized)
return FeedGen(stream_in, source_ids)
def TransformLayer(feed_stream, tnfms):
""" """
pass
def TransformLayer(feed_stream, tnfms, tnfm_args, tnfm_kwargs):
"""
A generator that takes the expected output of a FeedGen, pipes it
through a given set of transforms, and runs the results throught a
MergeGen to output a unified stream. tnfms should be a list of
pointers to generator functions. tnfm_args should be a list of
tuples, representing the arguments to be passed to each transform.
tnfm_kwargs should be a list of dictionaries representing keyword
arguments to each transform.
"""
# We should have as many sets of args as we have transforms.
assert len(tnfms) == len(tnfm_args) == len(tnfm_kwargs)
# Create a copy of the stream for each transform.
split = tee(feed_stream, len(tnfms))
# Package each stream copy with a transform and set of args. Use a list
# so that we can re-use this for calculating hashes.
bundles = zip(split, iter(tnfms), iter(tnfm_args), iter(tnfm_kwargs))
# Convert the argument bundles into a tuple of transform objects.
transformed = tuple((tnfm(stream, *args, **kwargs)
for stream, tnfm, args, kwargs in iter(bundles)))
# Roundrobin the outputs of our transforms to create a single flat stream.
to_merge = roundrobin(*transformed)
merged = MergeGen()
if __name__ == "__main__":
source = SpecificEquityTrades()
-1
View File
@@ -53,7 +53,6 @@ def FeedGen(stream_in, source_ids):
assert_feed_protocol(message)
yield message
import nose.tools; nose.tools.set_trace()
# We should have only a done message left in each queue.
for queue in sources.itervalues():
assert len(queue) == 1, "Bad queue in FeedGen on exit: %s" % queue
+30 -32
View File
@@ -153,39 +153,41 @@ class FeedGenTestCase(TestCase):
sequential = chain(iter(events_a), iter(events_b))
self.run_FeedGen(sequential, expected, source_ids)
def test_full_feed_layer(self):
filter = [1,2]
#Set up source a.
args_a = tuple()
kwargs_a = {'sids' : [1,2,3,4],
'start' : datetime(2012,6,6,0),
'delta' : timedelta(minutes = 1),
'filter' : filter
}
#Set up source b.
args_b = tuple()
kwargs_b = {'sids' : [1,2,3,5],
'start' : datetime(2012,6,6,0),
'delta' : timedelta(minutes = 1),
'filter' : filter
}
#Set up source c.
args_c = tuple()
kwargs_c = {'sids' : [1,2,3,5],
'start' : datetime(2012,6,6,0),
'delta' : timedelta(minutes = 1),
'filter' : filter
}
source_a = SpecificEquityTrades(sids = [1,2,3,4],
start = datetime(2012,6,6,0),
delta = timedelta(minutes=1),
filter = filter
)
id_a = "SpecificEquityTradesd175237b28d2f52df208c97cf4af896e"
# Change the internal sid list to give us a different hash.
source_b = SpecificEquityTrades(sids = [1,2,3,5],
start = datetime(2012,6,6,0),
delta = timedelta(minutes=1),
filter = filter
)
sources = tuple(SpecificEquityTrades) * 3
source_args = (args_a, args_b, args_c)
source_kwargs = (kwargs_a, kwargs_b, kwargs_c)
id_b = 'SpecificEquityTrades2bf2c2d6d01d4dbfc0b2818438ea8151'
# Change the internal sid list to give us a different hash.
source_c = SpecificEquityTrades(sids = [1,2,3,6],
start = datetime(2012,6,6,0),
delta = timedelta(minutes=1),
filter = filter
)
id_c = 'SpecificEquityTrades16f7437db2d14e5373ef20025f49a3fe'
sources = (source_a, source_b, source_c)
source_ids = [id_a, id_b, id_c]
feed_out = PreTransformLayer(sources, source_args, source_kwargs)
to_list = list(feed_out)
copy = to_list[:]
expected = sorted(copy, compare_by_dt_source_id)
feed_out = PreTransformLayer(sources, source_ids)
l = list(feed_out)
assert to_list == expected
def mock_data_unframe(source_id, dt, type):
event = ndict()
@@ -210,7 +212,3 @@ def compare_by_dt_source_id(x,y):
else:
return 0
+37 -24
View File
@@ -14,8 +14,9 @@ from numbers import Number
from itertools import izip
from zipline import ndict
from zipline.gens.utils import hash_args, date_gen
from zipline.gens.utils import assert_feed_unframe_protocol, assert_transform_protocol
from zipline.gens.tradegens import date_gen
from zipline.gens.utils import assert_feed_unframe_protocol, \
assert_transform_protocol, hash_args
import zipline.protocol as zp
@@ -43,6 +44,10 @@ def FunctionalTransformGen(stream_in, fun, *args, **kwargs):
"""
# TODO: Distinguish between functions and classes in hash_args.
# As implemented we will get assertion errors if a function and
# stateful class have the same name, which may or may not be
# what we want.
namestring = fun.__name__ + hash_args(*args, **kwargs)
for message in stream_in:
@@ -75,7 +80,10 @@ def MovingAverageTransformGen(stream_in, days, fields):
Generator that uses the MovingAverage state class to calculate
a moving average for all stocks over a specified number of days.
"""
return StatefulTransformGen(stream_in, MovingAverage, timedelta(days=days), fields)
return StatefulTransformGen(stream_in,
MovingAverage,
timedelta(days=days),
fields)
class MovingAverage(object):
"""
@@ -91,7 +99,7 @@ class MovingAverage(object):
# No way to pass arguments to the defaultdict factory, so we
# need to define a method to generate the correct EventWindows.
self.sid_windows = defaultdict(self.create_window)
def create_window(self):
"""Factory method for self.sid_windows."""
return EventWindow(self.delta, self.fields)
@@ -105,12 +113,16 @@ class MovingAverage(object):
assert isinstance(event, ndict),"Bad event in MovingAverage: %s" % event
assert event.has_key('sid'), "No sid in MovingAverage: %s" % event
output = ndict({'sid': event.sid})
# This will create a new EventWindow if this is the first
# message for this sid.
window = self.sid_windows[event.sid]
window.update(event)
averages = window.get_averages()
return window.get_averages()
# Return the calculated averages along with
output.merge(averages)
return output
class EventWindow(object):
"""
@@ -144,7 +156,7 @@ class EventWindow(object):
# newest oldest
# | |
# V V
while (self.ticks[-1].dt - self.ticks[0].dt) >= self.delta:
# popleft removes and returns ticks[0]
popped = self.ticks.popleft()
@@ -168,6 +180,7 @@ class EventWindow(object):
Return an ndict of all our tracked averages.
"""
out = ndict()
# out.ticks = len(self.ticks)
for field in self.fields:
out[field] = self.average(field)
return out
@@ -186,26 +199,26 @@ class EventWindow(object):
assert isinstance(event[field], Number), \
"Got %s for %s in EventWindow" % (event[field], field)
if __name__ == "__main__":
# if __name__ == "__main__":
def make_event(**kwargs):
e = ndict()
for key, value in kwargs.iteritems():
e[key] = value
return e
# def make_event(**kwargs):
# e = ndict()
# for key, value in kwargs.iteritems():
# e[key] = value
# return e
dates = date_gen(delta = timedelta(hours = 12))
events = (
make_event(
sid = 'foo', price = random.random(),
dt = date,
type = zp.DATASOURCE_TYPE.TRADE,
source_id = 'ds',
vol = i
)
for date, i in izip(dates, xrange(100))
)
# dates = date_gen(delta = timedelta(hours = 12))
# events = (
# make_event(
# sid = 'foo', price = random.random(),
# dt = date,
# type = zp.DATASOURCE_TYPE.TRADE,
# source_id = 'ds',
# vol = i
# )
# for date, i in izip(dates, xrange(100))
# )
gen = MovingAverageTransformGen(events, 1, ['price', 'vol'])
# gen = MovingAverageTransformGen(events, 1, ['price', 'vol'])
+3 -3
View File
@@ -85,14 +85,14 @@ def assert_feed_protocol(event):
assert event.type in DATASOURCE_TYPE
assert event.has_key('dt')
def assert_feed_unframe_protocol(event):
"""Same as above."""
assert isinstance(event, ndict)
assert isinstance(event.source_id, basestring)
assert event.type in DATASOURCE_TYPE
assert event.has_key('dt')
def assert_transform_protocol(event):
pass
"""Transforms should return an ndict to be merged by MergeGen."""
assert isinstance(event, ndict)