mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 05:15:44 +08:00
updates for transforms
This commit is contained in:
+54
-10
@@ -1,18 +1,62 @@
|
||||
from zipline.gens.utils import roundrobin
|
||||
from itertools import tee
|
||||
|
||||
from zipline.gens.tradegens import SpecificEquityTrades
|
||||
from zipline.gens.utils import roundrobin, hash_args
|
||||
from zipline.gens.feed import FeedGen
|
||||
from zipline.gens.merge import MergeGen
|
||||
|
||||
def PreTransformLayer(sources, source_args, source_kwargs):
|
||||
"""
|
||||
Takes a list of generator functions, a list of tuples of positional arguments,
|
||||
and a list of dictionaries of keyword arguments. Packages up all arguments
|
||||
and passes them into a FeedGen.
|
||||
"""
|
||||
assert len(sources) == len(source_args) == len(source_kwargs)
|
||||
# Package up sources and arguments.
|
||||
arg_bundles = zip(sources, source_args, source_kwargs)
|
||||
# Calculate namestring hashes to pass to FeedGen.
|
||||
namestrings = [source.__name__ + hash_args(*args, **kwargs)
|
||||
for source, args, kwargs in arg_bundles]
|
||||
# Pass each source its arguments.
|
||||
initialized = tuple(source(*args, **kwargs)
|
||||
for source, args, kwargs in arg_bundles)
|
||||
|
||||
def PreTransformLayer(sources, source_ids):
|
||||
"""
|
||||
A generator that takes a tuple of sources and a list of ids, piping
|
||||
their output into a feed_gen.
|
||||
"""
|
||||
stream_in = roundrobin(*sources)
|
||||
stream_in = roundrobin(*initialized)
|
||||
return FeedGen(stream_in, source_ids)
|
||||
|
||||
def TransformLayer(feed_stream, tnfms):
|
||||
""" """
|
||||
pass
|
||||
def TransformLayer(feed_stream, tnfms, tnfm_args, tnfm_kwargs):
|
||||
"""
|
||||
A generator that takes the expected output of a FeedGen, pipes it
|
||||
through a given set of transforms, and runs the results through a
|
||||
MergeGen to output a unified stream. tnfms should be a list of
|
||||
pointers to generator functions. tnfm_args should be a list of
|
||||
tuples, representing the arguments to be passed to each transform.
|
||||
tnfm_kwargs should be a list of dictionaries representing keyword
|
||||
arguments to each transform.
|
||||
"""
|
||||
# We should have as many sets of args as we have transforms.
|
||||
assert len(tnfms) == len(tnfm_args) == len(tnfm_kwargs)
|
||||
|
||||
# Create a copy of the stream for each transform.
|
||||
split = tee(feed_stream, len(tnfms))
|
||||
|
||||
# Package each stream copy with a transform and set of args. Use a list
|
||||
# so that we can re-use this for calculating hashes.
|
||||
bundles = zip(split, iter(tnfms), iter(tnfm_args), iter(tnfm_kwargs))
|
||||
|
||||
# Convert the argument bundles into a tuple of transform objects.
|
||||
transformed = tuple((tnfm(stream, *args, **kwargs)
|
||||
for stream, tnfm, args, kwargs in iter(bundles)))
|
||||
|
||||
# Roundrobin the outputs of our transforms to create a single flat stream.
|
||||
to_merge = roundrobin(*transformed)
|
||||
|
||||
merged = MergeGen()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
source = SpecificEquityTrades()
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -53,7 +53,6 @@ def FeedGen(stream_in, source_ids):
|
||||
assert_feed_protocol(message)
|
||||
yield message
|
||||
|
||||
import nose.tools; nose.tools.set_trace()
|
||||
# We should have only a done message left in each queue.
|
||||
for queue in sources.itervalues():
|
||||
assert len(queue) == 1, "Bad queue in FeedGen on exit: %s" % queue
|
||||
|
||||
+30
-32
@@ -153,39 +153,41 @@ class FeedGenTestCase(TestCase):
|
||||
|
||||
sequential = chain(iter(events_a), iter(events_b))
|
||||
self.run_FeedGen(sequential, expected, source_ids)
|
||||
|
||||
|
||||
def test_full_feed_layer(self):
|
||||
filter = [1,2]
|
||||
#Set up source a.
|
||||
args_a = tuple()
|
||||
kwargs_a = {'sids' : [1,2,3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
#Set up source b.
|
||||
args_b = tuple()
|
||||
kwargs_b = {'sids' : [1,2,3,5],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
#Set up source c.
|
||||
args_c = tuple()
|
||||
kwargs_c = {'sids' : [1,2,3,5],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
|
||||
source_a = SpecificEquityTrades(sids = [1,2,3,4],
|
||||
start = datetime(2012,6,6,0),
|
||||
delta = timedelta(minutes=1),
|
||||
filter = filter
|
||||
)
|
||||
id_a = "SpecificEquityTradesd175237b28d2f52df208c97cf4af896e"
|
||||
|
||||
# Change the internal sid list to give us a different hash.
|
||||
source_b = SpecificEquityTrades(sids = [1,2,3,5],
|
||||
start = datetime(2012,6,6,0),
|
||||
delta = timedelta(minutes=1),
|
||||
filter = filter
|
||||
)
|
||||
sources = tuple(SpecificEquityTrades) * 3
|
||||
source_args = (args_a, args_b, args_c)
|
||||
source_kwargs = (kwargs_a, kwargs_b, kwargs_c)
|
||||
|
||||
id_b = 'SpecificEquityTrades2bf2c2d6d01d4dbfc0b2818438ea8151'
|
||||
|
||||
# Change the internal sid list to give us a different hash.
|
||||
source_c = SpecificEquityTrades(sids = [1,2,3,6],
|
||||
start = datetime(2012,6,6,0),
|
||||
delta = timedelta(minutes=1),
|
||||
filter = filter
|
||||
)
|
||||
id_c = 'SpecificEquityTrades16f7437db2d14e5373ef20025f49a3fe'
|
||||
|
||||
sources = (source_a, source_b, source_c)
|
||||
source_ids = [id_a, id_b, id_c]
|
||||
feed_out = PreTransformLayer(sources, source_args, source_kwargs)
|
||||
to_list = list(feed_out)
|
||||
copy = to_list[:]
|
||||
expected = sorted(copy, compare_by_dt_source_id)
|
||||
|
||||
feed_out = PreTransformLayer(sources, source_ids)
|
||||
l = list(feed_out)
|
||||
assert to_list == expected
|
||||
|
||||
def mock_data_unframe(source_id, dt, type):
|
||||
event = ndict()
|
||||
@@ -210,7 +212,3 @@ def compare_by_dt_source_id(x,y):
|
||||
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
+37
-24
@@ -14,8 +14,9 @@ from numbers import Number
|
||||
from itertools import izip
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.utils import hash_args, date_gen
|
||||
from zipline.gens.utils import assert_feed_unframe_protocol, assert_transform_protocol
|
||||
from zipline.gens.tradegens import date_gen
|
||||
from zipline.gens.utils import assert_feed_unframe_protocol, \
|
||||
assert_transform_protocol, hash_args
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
@@ -43,6 +44,10 @@ def FunctionalTransformGen(stream_in, fun, *args, **kwargs):
|
||||
"""
|
||||
|
||||
# TODO: Distinguish between functions and classes in hash_args.
|
||||
# As implemented we will get assertion errors if a function and
|
||||
# a stateful class have the same name, which may or may not be
|
||||
# what we want.
|
||||
|
||||
namestring = fun.__name__ + hash_args(*args, **kwargs)
|
||||
|
||||
for message in stream_in:
|
||||
@@ -75,7 +80,10 @@ def MovingAverageTransformGen(stream_in, days, fields):
|
||||
Generator that uses the MovingAverage state class to calculate
|
||||
a moving average for all stocks over a specified number of days.
|
||||
"""
|
||||
return StatefulTransformGen(stream_in, MovingAverage, timedelta(days=days), fields)
|
||||
return StatefulTransformGen(stream_in,
|
||||
MovingAverage,
|
||||
timedelta(days=days),
|
||||
fields)
|
||||
|
||||
class MovingAverage(object):
|
||||
"""
|
||||
@@ -91,7 +99,7 @@ class MovingAverage(object):
|
||||
# No way to pass arguments to the defaultdict factory, so we
|
||||
# need to define a method to generate the correct EventWindows.
|
||||
self.sid_windows = defaultdict(self.create_window)
|
||||
|
||||
|
||||
def create_window(self):
|
||||
"""Factory method for self.sid_windows."""
|
||||
return EventWindow(self.delta, self.fields)
|
||||
@@ -105,12 +113,16 @@ class MovingAverage(object):
|
||||
assert isinstance(event, ndict),"Bad event in MovingAverage: %s" % event
|
||||
assert event.has_key('sid'), "No sid in MovingAverage: %s" % event
|
||||
|
||||
output = ndict({'sid': event.sid})
|
||||
# This will create a new EventWindow if this is the first
|
||||
# message for this sid.
|
||||
window = self.sid_windows[event.sid]
|
||||
window.update(event)
|
||||
averages = window.get_averages()
|
||||
|
||||
return window.get_averages()
|
||||
# Return the calculated averages along with the sid.
|
||||
output.merge(averages)
|
||||
return output
|
||||
|
||||
class EventWindow(object):
|
||||
"""
|
||||
@@ -144,7 +156,7 @@ class EventWindow(object):
|
||||
# newest oldest
|
||||
# | |
|
||||
# V V
|
||||
|
||||
|
||||
while (self.ticks[-1].dt - self.ticks[0].dt) >= self.delta:
|
||||
# popleft removes and returns ticks[0]
|
||||
popped = self.ticks.popleft()
|
||||
@@ -168,6 +180,7 @@ class EventWindow(object):
|
||||
Return an ndict of all our tracked averages.
|
||||
"""
|
||||
out = ndict()
|
||||
# out.ticks = len(self.ticks)
|
||||
for field in self.fields:
|
||||
out[field] = self.average(field)
|
||||
return out
|
||||
@@ -186,26 +199,26 @@ class EventWindow(object):
|
||||
assert isinstance(event[field], Number), \
|
||||
"Got %s for %s in EventWindow" % (event[field], field)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# if __name__ == "__main__":
|
||||
|
||||
def make_event(**kwargs):
|
||||
e = ndict()
|
||||
for key, value in kwargs.iteritems():
|
||||
e[key] = value
|
||||
return e
|
||||
# def make_event(**kwargs):
|
||||
# e = ndict()
|
||||
# for key, value in kwargs.iteritems():
|
||||
# e[key] = value
|
||||
# return e
|
||||
|
||||
dates = date_gen(delta = timedelta(hours = 12))
|
||||
events = (
|
||||
make_event(
|
||||
sid = 'foo', price = random.random(),
|
||||
dt = date,
|
||||
type = zp.DATASOURCE_TYPE.TRADE,
|
||||
source_id = 'ds',
|
||||
vol = i
|
||||
)
|
||||
for date, i in izip(dates, xrange(100))
|
||||
)
|
||||
# dates = date_gen(delta = timedelta(hours = 12))
|
||||
# events = (
|
||||
# make_event(
|
||||
# sid = 'foo', price = random.random(),
|
||||
# dt = date,
|
||||
# type = zp.DATASOURCE_TYPE.TRADE,
|
||||
# source_id = 'ds',
|
||||
# vol = i
|
||||
# )
|
||||
# for date, i in izip(dates, xrange(100))
|
||||
# )
|
||||
|
||||
gen = MovingAverageTransformGen(events, 1, ['price', 'vol'])
|
||||
# gen = MovingAverageTransformGen(events, 1, ['price', 'vol'])
|
||||
|
||||
|
||||
|
||||
@@ -85,14 +85,14 @@ def assert_feed_protocol(event):
|
||||
assert event.type in DATASOURCE_TYPE
|
||||
assert event.has_key('dt')
|
||||
|
||||
|
||||
def assert_feed_unframe_protocol(event):
|
||||
"""Same as above."""
|
||||
assert isinstance(event, ndict)
|
||||
assert isinstance(event.source_id, basestring)
|
||||
assert event.type in DATASOURCE_TYPE
|
||||
assert event.has_key('dt')
|
||||
|
||||
|
||||
def assert_transform_protocol(event):
|
||||
pass
|
||||
"""Transforms should return an ndict to be merged by MergeGen."""
|
||||
assert isinstance(event, ndict)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user