Merge branch 'byebye_threadsim' of github.com:quantopian/zipline into byebye_threadsim

This commit is contained in:
fawce
2012-07-30 13:53:09 -04:00
11 changed files with 367 additions and 135 deletions
+1 -1
View File
@@ -83,7 +83,7 @@ class RandomEquityTrades(TradeDataSource):
class SpecificEquityTrades(TradeDataSource):
"""
Generates a random stream of trades for testing.
Generates a non-random stream of trades for testing.
"""
def init(self, event_list):
+76 -13
View File
@@ -1,17 +1,80 @@
import datetime
from itertools import tee
from zipline.gens.tradegens import SpecificEquityTrades
from zipline.gens.utils import roundrobin, hash_args
from zipline.gens.feed import FeedGen
from zipline.gens.tradegen import SpecificEquityTrades
from zipline.gens.transform
from zipline.gens.merge import MergeGen
from zipline.gens.transform import StatefulTransformGen
def PreTransformLayer(sources):
"""A generator that takes a list of sources and runs their output
through a FeedGen."""
not_finished = len_
while not_finished:
def PreTransformLayer(sources, source_args, source_kwargs):
"""
Takes a list of generator functions, a list of tuples of positional arguments,
and a list of dictionaries of keyword arguments. Packages up all arguments
and passes them into a FeedGen.
"""
assert len(sources) == len(source_args) == len(source_kwargs)
# Package up sources and arguments.
arg_bundles = zip(sources, source_args, source_kwargs)
# Calculate namestring hashes to pass to FeedGen.
namestrings = [source.__name__ + hash_args(*args, **kwargs)
for source, args, kwargs in arg_bundles]
# Pass each source its arguments.
initialized = tuple(source(*args, **kwargs)
for source, args, kwargs in arg_bundles)
stream_in = roundrobin(*initialized)
return FeedGen(stream_in, namestrings)
def TransformLayer(feed_stream, tnfms, tnfm_args, tnfm_kwargs):
"""
A generator that takes the expected output of a FeedGen, pipes it
through a given set of transforms, and runs the results throught a
MergeGen to output a unified stream. tnfms should be a list of
pointers to generator functions. tnfm_args should be a list of
tuples, representing the arguments to be passed to each transform.
tnfm_kwargs should be a list of dictionaries representing keyword
arguments to each transform.
"""
# We should have as many sets of args as we have transforms.
assert len(tnfms) == len(tnfm_args) == len(tnfm_kwargs)
# Create a copy of the stream for each transform.
split = tee(feed_stream, len(tnfms))
# Package each stream copy with a transform and set of args. Use a list
# so that we can re-use this for calculating hashes.
bundles = zip(split, tnfms, tnfm_args, tnfm_kwargs)
tnfm_gens = [StatefulTransformGen(stream, tnfm, *args, **kwargs)
for stream, tnfm, args, kwargs in bundles]
# Generate expected hashes for each transform
hashes = [tnfm.__name__ + hash_args(*args, **kwargs)
for _, tnfm, args, kwargs in bundles]
# Roundrobin the outputs of our transforms to create a single flat stream.
to_merge = roundrobin(*tnfm_gens)
# Pipe the stream into MergeGen.
merged = MergeGen(to_merge, hashes)
return merged
if __name__ == "__main__":
from zipline.gens.transform import MovingAverage, Passthrough
import nose.tools; nose.tools.set_trace()
source = SpecificEquityTrades
feed_out = PreTransformLayer((source,), ((),), ({},))
transforms = [MovingAverage, Passthrough]
args = [(datetime.timedelta(days = 1), ['price']), ()]
kwargs = [{}, {}]
tlayer = TransformLayer(feed_out, transforms, args, kwargs)
+8 -6
View File
@@ -13,7 +13,8 @@ from collections import deque, defaultdict
from zipline import ndict
from zipline.gens.utils import hash_args, assert_datasource_protocol, \
assert_trade_protocol, assert_datasource_unframe_protocol
assert_trade_protocol, assert_datasource_unframe_protocol, \
assert_feed_protocol
import zipline.protocol as zp
@@ -38,7 +39,7 @@ def FeedGen(stream_in, source_ids):
# Incoming messages should be the output of DATASOURCE_UNFRAME.
assert_datasource_unframe_protocol(message), \
"Bad message in FeedGen: %s" % message
# Only allow messages from sources we expect.
assert message.source_id in sources, "Unexpected source: %s" % message
@@ -49,9 +50,9 @@ def FeedGen(stream_in, source_ids):
while full(sources) and not done(sources):
message = pop_oldest(sources)
assert feed_protocol(message)
assert_feed_protocol(message)
yield message
# We should have only a done message left in each queue.
for queue in sources.itervalues():
assert len(queue) == 1, "Bad queue in FeedGen on exit: %s" % queue
@@ -60,8 +61,9 @@ def FeedGen(stream_in, source_ids):
def full(sources):
"""
Feed is full when every internal queue has at least one message. Note that
this include DONE messages, so done(sources) is True only if full(sources).
Feed is full when every internal queue has at least one
message. Note that this include DONE messages, so done(sources) is
True only if full(sources).
"""
assert isinstance(sources, dict)
return all( (queue_is_full(source) for source in sources.itervalues()) )
+21 -19
View File
@@ -13,7 +13,7 @@ from collections import deque, defaultdict
from zipline import ndict
from zipline.gens.utils import hash_args, assert_datasource_protocol, \
assert_trade_protocol, assert_datasource_unframe_protocol
assert_trade_protocol, assert_datasource_unframe_protocol, assert_merge_protocol
import zipline.protocol as zp
@@ -25,44 +25,46 @@ def MergeGen(stream_in, tnfm_ids):
and merge them together into an event. We raise an error if we
do not receive the same number of events from all sources.
"""
assert isinstance(source_ids, list)
assert isinstance(tnfm_ids, list)
# Set up an internal queue for each expected source.
sources = {}
for id in source_ids:
assert isinstance(id, basestring), "Bad source_id %s" % source_id
sources[id] = deque()
tnfms = {}
for id in tnfm_ids:
assert isinstance(id, basestring), "Bad source_id %s" % id
tnfms[id] = deque()
# Process incoming streams.
for message in stream_in:
assert isinstance(message, ndict), \
assert isinstance(message, tuple), \
"Bad message in MergeGen: %s" %message
assert message.tnfm_id in tnfm_ids, \
"Message from unexpected tnfm: %s, %s" % (message, tnfm_ids)
assert len(message) == 2
id, value = message
assert id in tnfm_ids, \
"Message from unexpected tnfm: %s, %s" % (id, tnfm_ids)
assert isinstance(value, ndict), "Bad message in MergeGen: %s" %message
assert message.has_key('value')
source[message.tnfm_id].append(message)
tnfms[id].append(value)
# Only pop messages when we have a pending message from
# all datasources. Stop if all sources have signalled done.
while full(sources) and not done(sources):
message = merge_one(sources)
assert merge_protocol(message)
while full(tnfms) and not done(tnfms):
message = merge_one(tnfms)
assert_merge_protocol(tnfm_ids, message)
yield message
# We should have only a done message left in each queue.
for queue in sources.itervalues():
for queue in tnfms.itervalues():
assert len(queue) == 1, "Bad queue in MergeGen on exit: %s" % queue
assert queue[0].dt == "DONE", \
"Bad last message in MergeGen on exit: %s" % queue
def merge_one(sources):
output = ndict()
for queue in sources.itervalues():
output.merge(queue.popleft())
for key, queue in sources.iteritems():
new_xform = ndict({key: queue.popleft()})
output.merge(new_xform)
return output
+65 -16
View File
@@ -13,8 +13,10 @@ from collections import deque
from zipline import ndict
from zipline.gens.feed import FeedGen, full, done, queue_is_full,queue_is_done,\
pop_oldest
from zipline.gens.utils import stringify_args, assert_datasource_protocol,\
assert_trade_protocol, date_gen, alternate
from zipline.gens.utils import hash_args, assert_datasource_protocol,\
assert_trade_protocol, alternate
from zipline.gens.tradegens import date_gen, SpecificEquityTrades
from zipline.gens.composites import PreTransformLayer
import zipline.protocol as zp
@@ -98,12 +100,11 @@ class FeedGenTestCase(TestCase):
l = list(feed_gen)
assert l == expected
def test_single_source(self):
source_ids = ['a']
# 100 events, increasing by a minute at a time.
type = zp.DATASOURCE_TYPE.TRADE
dates = list(date_gen(n = 1))
dates = list(date_gen(count = 100))
dates.append("DONE")
# [('a', date1, type), ('a', date2, type), ... ('a', "DONE", type)]
@@ -125,7 +126,7 @@ class FeedGenTestCase(TestCase):
# Set up source 'a'. Outputs 20 events with 2 minute deltas.
delta_a = timedelta(minutes = 2)
dates_a = list(date_gen(delta = delta_a, n = 20))
dates_a = list(date_gen(delta = delta_a, count = 20))
dates_a.append("DONE")
events_a_args = zip(cycle(['a']), iter(dates_a), cycle([type]))
@@ -133,7 +134,7 @@ class FeedGenTestCase(TestCase):
# Set up source 'b'. Outputs 10 events with 1 minute deltas.
delta_b = timedelta(minutes = 1)
dates_b = list(date_gen(delta = delta_b, n = 10))
dates_b = list(date_gen(delta = delta_b, count = 10))
dates_b.append("DONE")
events_b_args = zip(cycle(['b']), iter(dates_b), cycle([type]))
@@ -152,13 +153,65 @@ class FeedGenTestCase(TestCase):
sequential = chain(iter(events_a), iter(events_b))
self.run_FeedGen(sequential, expected, source_ids)
def test_full_feed_layer(self):
filter = [1,2]
#Set up source a. One hour between events.
args_a = tuple()
kwargs_a = {'sids' : [1,2,3,4],
'start' : datetime(2012,6,6,0),
'delta' : timedelta(hours = 1),
'filter' : filter
}
#Set up source b. One day between events.
args_b = tuple()
kwargs_b = {'sids' : [1,2,3,4],
'start' : datetime(2012,6,6,0),
'delta' : timedelta(days = 1),
'filter' : filter
}
#Set up source c. One minute between events.
args_c = tuple()
kwargs_c = {'sids' : [1,2,3,4],
'start' : datetime(2012,6,6,0),
'delta' : timedelta(minutes = 1),
'filter' : filter
}
# Set up source d. This should produce no events because the
# internal sids don't match the filter.
args_d = tuple()
kwargs_d = {'sids' : [3,4],
'start' : datetime(2012,6,6,0),
'delta' : timedelta(minutes = 1),
'filter' : filter
}
sources = (SpecificEquityTrades,) * 4
source_args = (args_a, args_b, args_c, args_d)
source_kwargs = (kwargs_a, kwargs_b, kwargs_c, kwargs_d)
# Generate our expected source_ids.
zip_args = zip(source_args, source_kwargs)
expected_ids = ["SpecificEquityTrades" + hash_args(*args, **kwargs)
for args, kwargs in zip_args]
# Pipe our sources into feed.
feed_out = PreTransformLayer(sources, source_args, source_kwargs)
# Read all the values from feed and assert that they arrive in
# the correct sorting with the expected hash values.
to_list = list(feed_out)
copy = to_list[:]
for e in to_list:
# All events should match one of our expected source_ids.
assert e.source_id in expected_ids
# But none of them should match source_d.
assert e.source_id != hash_args(*args_d, **kwargs_d)
expected = sorted(copy, compare_by_dt_source_id)
assert to_list == expected
def test_with_specific_equity(self):
def mock_data_unframe(source_id, dt, type):
event = ndict()
event.source_id = source_id
@@ -182,7 +235,3 @@ def compare_by_dt_source_id(x,y):
else:
return 0
+2 -2
View File
@@ -10,7 +10,7 @@ from itertools import izip, izip_longest
from datetime import datetime, timedelta
from zipline.gens.mongods import create_pymongo_iterator, MongoTradeHistoryGen
from zipline.gens.utils import stringify_args, assert_datasource_protocol,\
from zipline.gens.utils import hash_args, assert_datasource_protocol,\
assert_trade_protocol, mock_raw_event
import zipline.protocol as zp
@@ -107,7 +107,7 @@ class TestMongoDataGenerator(TestCase):
for field in iter(['sid', 'dt', 'price', 'volume']):
assert db[field] == expected[field]
# Expected output of stringify_args:
# Expected output of hash_args:
assert db['source_id'] == \
'MongoTradeHistoryGen983a27fd0710414239a5cde71ef5a8fc'
+108 -26
View File
@@ -1,44 +1,126 @@
import random
from itertools import chain, repeat, cycle, ifilter
from itertools import chain, repeat, cycle, ifilter, izip
from datetime import datetime, timedelta
from zipline.utils.factory import create_trade, create_trade
from zipline.gens.utils import date_gen
from zipline.utils.factory import create_trade
from zipline.gens.utils import hash_args, mock_done
def mock_prices(n, rand = False):
"""Utility to generate a set of prices. By default
cycles through values from 0.0 to 10.0 n times. Optional
flag to give random values between 0.0 and 10.0"""
def date_gen(start = datetime(2012, 6, 6, 0),
delta = timedelta(minutes = 1),
count = 100):
"""
Utility to generate a stream of dates.
"""
return (start + (i * delta) for i in xrange(count))
def mock_prices(count, rand = False):
"""
Utility to generate a stream of mock prices. By default
cycles through values from 0.0 to 10.0, n times. Optional
flag to give random values between 0.0 and 10.0
"""
if rand:
return (random.uniform(0.0, 10.0) for i in xrange(n))
return (random.uniform(0.0, 10.0) for i in xrange(count))
else:
return (float(i % 11) for i in xrange(1,n+1))
return (float(i % 11) for i in xrange(1,count+1))
def mock_volumes(n, rand = False):
"""Does the same as mock_prices. Different function name
for readability."""
return mock_prices(n, rand)
def SpecificEquityTrades(n = 500, sids = [1, 2], event_list = None, filter = None):
"""Returns the first n events of event_list if specified.
Otherwise generates a sensible stream of events."""
def mock_volumes(count, rand = False):
"""
Utility to generate a set of volumes. By default cycles
through values from 100 to 1000, incrementing by 50. Optional
flag to give random values between 100 and 1000.
"""
if rand:
return (random.randrange(100, 1000) for i in xrange(count))
else:
return ((i * 50)%900 + 100 for i in xrange(count))
def fuzzy_dates(count = 500):
"""
Add +-10 seconds to each event from a date_gen. Note that this
still guarantees sorting, since the default on date_gen is minute
separation of events.
"""
for date in date_gen(count = count):
yield date + timedelta(seconds = random.randint(-10, 10))
def SpecificEquityTrades(*args, **config):
"""
Yields all events in event_list that match the given sid_filter.
If no event_list is specified, generates an internal stream of events
to filter. Returns all events if filter is None.
"""
# We shouldn't get any positional arguments.
assert args == ()
# Unpack config dictionary with default values.
count = config.get('count', 500)
sids = config.get('sids', [1, 2])
start = config.get('start', datetime(2012, 6, 6, 0))
delta = config.get('delta', timedelta(minutes = 1))
# Default to None for event_list and filter.
event_list = config.get('event_list')
filter = config.get('filter')
arg_string = hash_args(*args, **config)
namestring = "SpecificEquityTrades" + arg_string
# If we have an event_list, ignore the other arguments and use the list.
# TODO: still append our namestring?
if event_list:
unfiltered = (event for event in event_list)
# Set up iterators for each expected field.
else:
dates = date_gen(n = n)
prices = mock_prices(n)
volumes = mock_volumes(n)
sids = cycle(iter(sids))
dates = date_gen(count = count, start = start, delta = delta)
prices = mock_prices(count)
volumes = mock_volumes(count)
sids = cycle(sids)
# Combine the iterators into a single iterator of arguments
arg_gen = izip(sids, prices, volumes, dates)
unfiltered = (create_trade(*args) for args in arg_gen)
# Convert argument packages into events.
unfiltered = (create_trade(*args, source_id = namestring)
for args in arg_gen)
# If we specified a sid filter, filter out elements that don't match the filter.
if filter:
filtered = ifilter(lambda event: event.sid in filter)
filtered = ifilter(lambda event: event.sid in filter, unfiltered)
# Otherwise just use all events.
else:
filtered = unfiltered
# Add a done message to the end of the stream. For a live
# datasource this would be handled by the containing Component.
out = chain(filtered, [mock_done(namestring)])
return out
def RandomEquityTrades(*args, **config):
# We shouldn't get any positional args.
assert args == ()
count = config.get('count', 500)
sids = config.get('sids', [1,2])
filter = config.get('filter')
dates = fuzzy_dates(count)
prices = mock_prices(count, rand = True)
volumes = mock_volumes(count, rand = True)
sids = cycle(sids)
arg_gen = izip(sids, prices, volumes, dates)
unfiltered = (create_trade(*args) for args in arg_gen)
if filter:
filtered = ifilter(lambda event: event.sid in filter, unfiltered)
else:
filtered = unfiltered
return filtered
# if __name__ == "__main__":
# import nose.tools; nose.tools.set_trace()
# trades = SpecificEquityTrades(filter = [1])
+45 -44
View File
@@ -14,26 +14,24 @@ from numbers import Number
from itertools import izip
from zipline import ndict
from zipline.gens.utils import hash_args, date_gen
from zipline.gens.utils import assert_feed_unframe_protocol, assert_transform_protocol
from zipline.gens.tradegens import date_gen
from zipline.gens.utils import assert_feed_unframe_protocol, \
assert_transform_protocol, hash_args
import zipline.protocol as zp
def PassthroughTransformGen(stream_in):
"""Trivial transform for event forwarding."""
class Passthrough(object):
"""
Trivial function for forwarding events.
"""
def __init__(self):
pass
# hash_args with no arguments is the same as:
# hasher = hashlib.md5()
# hasher.update(":");
# hashlib.md5.digest().
namestring = "Passthrough" + hash_args()
for message in stream_in:
assert_feed_unframe_protocol(message)
out_value = message
assert_transform_protocol(out_value)
yield (namestring, out_value)
def update(self, event):
assert isinstance(event, ndict),"Bad event in Passthrough: %s" % event
assert event.has_key('sid'), "No sid in Passthrough: %s" % event
assert event.has_key('dt'), "No dt in Passthorughz: %s" % event
return event
def FunctionalTransformGen(stream_in, fun, *args, **kwargs):
"""
@@ -43,6 +41,10 @@ def FunctionalTransformGen(stream_in, fun, *args, **kwargs):
"""
# TODO: Distinguish between functions and classes in hash_args.
# As implemented we will get assertion errors if a function and
# stateful class have the same name, which may or may not be
# what we want.
namestring = fun.__name__ + hash_args(*args, **kwargs)
for message in stream_in:
@@ -70,13 +72,6 @@ def StatefulTransformGen(stream_in, tnfm_class, *args, **kwargs):
assert_transform_protocol(out_value)
yield (namestring, out_value)
def MovingAverageTransformGen(stream_in, days, fields):
"""
Generator that uses the MovingAverage state class to calculate
a moving average for all stocks over a specified number of days.
"""
return StatefulTransformGen(stream_in, MovingAverage, timedelta(days=days), fields)
class MovingAverage(object):
"""
Class that maintains a dictionary from sids to EventWindows
@@ -91,7 +86,7 @@ class MovingAverage(object):
# No way to pass arguments to the defaultdict factory, so we
# need to define a method to generate the correct EventWindows.
self.sid_windows = defaultdict(self.create_window)
def create_window(self):
"""Factory method for self.sid_windows."""
return EventWindow(self.delta, self.fields)
@@ -104,13 +99,18 @@ class MovingAverage(object):
assert isinstance(event, ndict),"Bad event in MovingAverage: %s" % event
assert event.has_key('sid'), "No sid in MovingAverage: %s" % event
assert event.has_key('dt'), "No dt in MovingAverage: %s" % event
output = ndict({'sid': event.sid, 'dt': event.dt})
# This will create a new EventWindow if this is the first
# message for this sid.
window = self.sid_windows[event.sid]
window.update(event)
averages = window.get_averages()
return window.get_averages()
# Return the calculated averages along with
output.merge(averages)
return output
class EventWindow(object):
"""
@@ -144,7 +144,7 @@ class EventWindow(object):
# newest oldest
# | |
# V V
while (self.ticks[-1].dt - self.ticks[0].dt) >= self.delta:
# popleft removes and returns ticks[0]
popped = self.ticks.popleft()
@@ -168,6 +168,7 @@ class EventWindow(object):
Return an ndict of all our tracked averages.
"""
out = ndict()
# out.ticks = len(self.ticks)
for field in self.fields:
out[field] = self.average(field)
return out
@@ -186,26 +187,26 @@ class EventWindow(object):
assert isinstance(event[field], Number), \
"Got %s for %s in EventWindow" % (event[field], field)
if __name__ == "__main__":
# if __name__ == "__main__":
def make_event(**kwargs):
e = ndict()
for key, value in kwargs.iteritems():
e[key] = value
return e
# def make_event(**kwargs):
# e = ndict()
# for key, value in kwargs.iteritems():
# e[key] = value
# return e
dates = date_gen(delta = timedelta(hours = 12))
events = (
make_event(
sid = 'foo', price = random.random(),
dt = date,
type = zp.DATASOURCE_TYPE.TRADE,
source_id = 'ds',
vol = i
)
for date, i in izip(dates, xrange(100))
)
# dates = date_gen(delta = timedelta(hours = 12))
# events = (
# make_event(
# sid = 'foo', price = random.random(),
# dt = date,
# type = zp.DATASOURCE_TYPE.TRADE,
# source_id = 'ds',
# vol = i
# )
# for date, i in izip(dates, xrange(100))
# )
gen = MovingAverageTransformGen(events, 1, ['price', 'vol'])
# gen = MovingAverageTransformGen(events, 1, ['price', 'vol'])
+23 -6
View File
@@ -16,16 +16,28 @@ def mock_raw_event(sid, dt):
}
return event
def date_gen(start = datetime(2012, 6, 6, 0), delta = timedelta(minutes = 1), n = 100):
return (start + i * delta for i in xrange(n))
def mock_done(source_id):
return ndict({'dt': "DONE", "source_id" : source_id, 'type' : 0})
def alternate(g1, g2):
"""Specialized version of roundrobin for just 2 generators."""
for e1, e2 in izip_longest(g1, g2):
if e1 != None:
yield e1
if e2 != None:
yield e2
def roundrobin(*args):
"""
Takes N generators, pulling one element off each until all inputs
are empty.
"""
for elem_tuple in izip_longest(*args):
for value in elem_tuple:
if value != None:
yield value
def hash_args(*args, **kwargs):
"""Define a unique string for any set of representable args."""
arg_string = '_'.join([str(arg) for arg in args])
@@ -73,14 +85,19 @@ def assert_feed_protocol(event):
assert event.type in DATASOURCE_TYPE
assert event.has_key('dt')
def assert_feed_unframe_protocol(event):
"""Same as above."""
assert isinstance(event, ndict)
assert isinstance(event.source_id, basestring)
assert event.type in DATASOURCE_TYPE
assert event.has_key('dt')
def assert_transform_protocol(event):
pass
"""Transforms should return an ndict to be merged by MergeGen."""
assert isinstance(event, ndict)
def assert_merge_protocol(tnfm_ids, message):
"""Merge should output an ndict with a field for each id in its transform set."""
assert isinstance(message, ndict)
assert set(tnfm_ids) == set(message.keys())
+16
View File
@@ -0,0 +1,16 @@
import zmq
import zipline.protocol as zp
def gen_from_zmq(poller, unframe):
"""
A generator that takes an initialized zmq poller and yields
messages from the poller until it gets a zp.CONTROL_PROTOCOL.DONE.
"""
while True:
message = poller.recv()
if message = zp.CONTROL_PROTOCOL.DONE:
yield "DONE"
break
else:
yield unframe(message)
+2 -2
View File
@@ -69,9 +69,9 @@ def create_trading_environment(year=2006):
return trading_environment
def create_trade(sid, price, amount, datetime):
def create_trade(sid, price, amount, datetime, source_id = "test_factory"):
row = zp.ndict({
'source_id' : "test_factory",
'source_id' : source_id,
'type' : zp.DATASOURCE_TYPE.TRADE,
'sid' : sid,
'dt' : datetime,