mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 01:53:27 +08:00
Merge branch 'byebye_threadsim' of github.com:quantopian/zipline into byebye_threadsim
This commit is contained in:
@@ -83,7 +83,7 @@ class RandomEquityTrades(TradeDataSource):
|
||||
|
||||
class SpecificEquityTrades(TradeDataSource):
|
||||
"""
|
||||
Generates a random stream of trades for testing.
|
||||
Generates a non-random stream of trades for testing.
|
||||
"""
|
||||
|
||||
def init(self, event_list):
|
||||
|
||||
+76
-13
@@ -1,17 +1,80 @@
|
||||
import datetime
|
||||
from itertools import tee
|
||||
|
||||
from zipline.gens.tradegens import SpecificEquityTrades
|
||||
from zipline.gens.utils import roundrobin, hash_args
|
||||
from zipline.gens.feed import FeedGen
|
||||
from zipline.gens.tradegen import SpecificEquityTrades
|
||||
from zipline.gens.transform
|
||||
from zipline.gens.merge import MergeGen
|
||||
from zipline.gens.transform import StatefulTransformGen
|
||||
|
||||
|
||||
|
||||
|
||||
def PreTransformLayer(sources):
|
||||
"""A generator that takes a list of sources and runs their output
|
||||
through a FeedGen."""
|
||||
not_finished = len_
|
||||
|
||||
while not_finished:
|
||||
|
||||
|
||||
def PreTransformLayer(sources, source_args, source_kwargs):
|
||||
"""
|
||||
Takes a list of generator functions, a list of tuples of positional arguments,
|
||||
and a list of dictionaries of keyword arguments. Packages up all arguments
|
||||
and passes them into a FeedGen.
|
||||
"""
|
||||
assert len(sources) == len(source_args) == len(source_kwargs)
|
||||
# Package up sources and arguments.
|
||||
arg_bundles = zip(sources, source_args, source_kwargs)
|
||||
|
||||
# Calculate namestring hashes to pass to FeedGen.
|
||||
namestrings = [source.__name__ + hash_args(*args, **kwargs)
|
||||
for source, args, kwargs in arg_bundles]
|
||||
# Pass each source its arguments.
|
||||
initialized = tuple(source(*args, **kwargs)
|
||||
for source, args, kwargs in arg_bundles)
|
||||
|
||||
stream_in = roundrobin(*initialized)
|
||||
return FeedGen(stream_in, namestrings)
|
||||
|
||||
|
||||
def TransformLayer(feed_stream, tnfms, tnfm_args, tnfm_kwargs):
|
||||
"""
|
||||
A generator that takes the expected output of a FeedGen, pipes it
|
||||
through a given set of transforms, and runs the results throught a
|
||||
MergeGen to output a unified stream. tnfms should be a list of
|
||||
pointers to generator functions. tnfm_args should be a list of
|
||||
tuples, representing the arguments to be passed to each transform.
|
||||
tnfm_kwargs should be a list of dictionaries representing keyword
|
||||
arguments to each transform.
|
||||
"""
|
||||
|
||||
# We should have as many sets of args as we have transforms.
|
||||
assert len(tnfms) == len(tnfm_args) == len(tnfm_kwargs)
|
||||
|
||||
# Create a copy of the stream for each transform.
|
||||
split = tee(feed_stream, len(tnfms))
|
||||
|
||||
# Package each stream copy with a transform and set of args. Use a list
|
||||
# so that we can re-use this for calculating hashes.
|
||||
bundles = zip(split, tnfms, tnfm_args, tnfm_kwargs)
|
||||
|
||||
tnfm_gens = [StatefulTransformGen(stream, tnfm, *args, **kwargs)
|
||||
for stream, tnfm, args, kwargs in bundles]
|
||||
|
||||
# Generate expected hashes for each transform
|
||||
hashes = [tnfm.__name__ + hash_args(*args, **kwargs)
|
||||
for _, tnfm, args, kwargs in bundles]
|
||||
|
||||
# Roundrobin the outputs of our transforms to create a single flat stream.
|
||||
to_merge = roundrobin(*tnfm_gens)
|
||||
|
||||
# Pipe the stream into MergeGen.
|
||||
merged = MergeGen(to_merge, hashes)
|
||||
return merged
|
||||
|
||||
if __name__ == "__main__":
|
||||
from zipline.gens.transform import MovingAverage, Passthrough
|
||||
|
||||
import nose.tools; nose.tools.set_trace()
|
||||
source = SpecificEquityTrades
|
||||
feed_out = PreTransformLayer((source,), ((),), ({},))
|
||||
|
||||
transforms = [MovingAverage, Passthrough]
|
||||
args = [(datetime.timedelta(days = 1), ['price']), ()]
|
||||
kwargs = [{}, {}]
|
||||
|
||||
tlayer = TransformLayer(feed_out, transforms, args, kwargs)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,8 @@ from collections import deque, defaultdict
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.utils import hash_args, assert_datasource_protocol, \
|
||||
assert_trade_protocol, assert_datasource_unframe_protocol
|
||||
assert_trade_protocol, assert_datasource_unframe_protocol, \
|
||||
assert_feed_protocol
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
@@ -38,7 +39,7 @@ def FeedGen(stream_in, source_ids):
|
||||
# Incoming messages should be the output of DATASOURCE_UNFRAME.
|
||||
assert_datasource_unframe_protocol(message), \
|
||||
"Bad message in FeedGen: %s" % message
|
||||
|
||||
|
||||
# Only allow messages from sources we expect.
|
||||
assert message.source_id in sources, "Unexpected source: %s" % message
|
||||
|
||||
@@ -49,9 +50,9 @@ def FeedGen(stream_in, source_ids):
|
||||
|
||||
while full(sources) and not done(sources):
|
||||
message = pop_oldest(sources)
|
||||
assert feed_protocol(message)
|
||||
assert_feed_protocol(message)
|
||||
yield message
|
||||
|
||||
|
||||
# We should have only a done message left in each queue.
|
||||
for queue in sources.itervalues():
|
||||
assert len(queue) == 1, "Bad queue in FeedGen on exit: %s" % queue
|
||||
@@ -60,8 +61,9 @@ def FeedGen(stream_in, source_ids):
|
||||
|
||||
def full(sources):
|
||||
"""
|
||||
Feed is full when every internal queue has at least one message. Note that
|
||||
this include DONE messages, so done(sources) is True only if full(sources).
|
||||
Feed is full when every internal queue has at least one
|
||||
message. Note that this include DONE messages, so done(sources) is
|
||||
True only if full(sources).
|
||||
"""
|
||||
assert isinstance(sources, dict)
|
||||
return all( (queue_is_full(source) for source in sources.itervalues()) )
|
||||
|
||||
+21
-19
@@ -13,7 +13,7 @@ from collections import deque, defaultdict
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.utils import hash_args, assert_datasource_protocol, \
|
||||
assert_trade_protocol, assert_datasource_unframe_protocol
|
||||
assert_trade_protocol, assert_datasource_unframe_protocol, assert_merge_protocol
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
@@ -25,44 +25,46 @@ def MergeGen(stream_in, tnfm_ids):
|
||||
and merge them together into an event. We raise an error if we
|
||||
do not receive the same number of events from all sources.
|
||||
"""
|
||||
|
||||
assert isinstance(source_ids, list)
|
||||
|
||||
assert isinstance(tnfm_ids, list)
|
||||
|
||||
# Set up an internal queue for each expected source.
|
||||
sources = {}
|
||||
for id in source_ids:
|
||||
assert isinstance(id, basestring), "Bad source_id %s" % source_id
|
||||
sources[id] = deque()
|
||||
tnfms = {}
|
||||
for id in tnfm_ids:
|
||||
assert isinstance(id, basestring), "Bad source_id %s" % id
|
||||
tnfms[id] = deque()
|
||||
|
||||
# Process incoming streams.
|
||||
for message in stream_in:
|
||||
assert isinstance(message, ndict), \
|
||||
assert isinstance(message, tuple), \
|
||||
"Bad message in MergeGen: %s" %message
|
||||
assert message.tnfm_id in tnfm_ids, \
|
||||
"Message from unexpected tnfm: %s, %s" % (message, tnfm_ids)
|
||||
assert len(message) == 2
|
||||
id, value = message
|
||||
assert id in tnfm_ids, \
|
||||
"Message from unexpected tnfm: %s, %s" % (id, tnfm_ids)
|
||||
assert isinstance(value, ndict), "Bad message in MergeGen: %s" %message
|
||||
|
||||
assert message.has_key('value')
|
||||
|
||||
source[message.tnfm_id].append(message)
|
||||
tnfms[id].append(value)
|
||||
|
||||
# Only pop messages when we have a pending message from
|
||||
# all datasources. Stop if all sources have signalled done.
|
||||
|
||||
while full(sources) and not done(sources):
|
||||
message = merge_one(sources)
|
||||
assert merge_protocol(message)
|
||||
while full(tnfms) and not done(tnfms):
|
||||
message = merge_one(tnfms)
|
||||
assert_merge_protocol(tnfm_ids, message)
|
||||
yield message
|
||||
|
||||
# We should have only a done message left in each queue.
|
||||
for queue in sources.itervalues():
|
||||
for queue in tnfms.itervalues():
|
||||
assert len(queue) == 1, "Bad queue in MergeGen on exit: %s" % queue
|
||||
assert queue[0].dt == "DONE", \
|
||||
"Bad last message in MergeGen on exit: %s" % queue
|
||||
|
||||
def merge_one(sources):
|
||||
output = ndict()
|
||||
for queue in sources.itervalues():
|
||||
output.merge(queue.popleft())
|
||||
for key, queue in sources.iteritems():
|
||||
new_xform = ndict({key: queue.popleft()})
|
||||
output.merge(new_xform)
|
||||
return output
|
||||
|
||||
|
||||
|
||||
+65
-16
@@ -13,8 +13,10 @@ from collections import deque
|
||||
from zipline import ndict
|
||||
from zipline.gens.feed import FeedGen, full, done, queue_is_full,queue_is_done,\
|
||||
pop_oldest
|
||||
from zipline.gens.utils import stringify_args, assert_datasource_protocol,\
|
||||
assert_trade_protocol, date_gen, alternate
|
||||
from zipline.gens.utils import hash_args, assert_datasource_protocol,\
|
||||
assert_trade_protocol, alternate
|
||||
from zipline.gens.tradegens import date_gen, SpecificEquityTrades
|
||||
from zipline.gens.composites import PreTransformLayer
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
@@ -98,12 +100,11 @@ class FeedGenTestCase(TestCase):
|
||||
l = list(feed_gen)
|
||||
assert l == expected
|
||||
|
||||
|
||||
def test_single_source(self):
|
||||
source_ids = ['a']
|
||||
# 100 events, increasing by a minute at a time.
|
||||
type = zp.DATASOURCE_TYPE.TRADE
|
||||
dates = list(date_gen(n = 1))
|
||||
dates = list(date_gen(count = 100))
|
||||
dates.append("DONE")
|
||||
|
||||
# [('a', date1, type), ('a', date2, type), ... ('a', "DONE", type)]
|
||||
@@ -125,7 +126,7 @@ class FeedGenTestCase(TestCase):
|
||||
|
||||
# Set up source 'a'. Outputs 20 events with 2 minute deltas.
|
||||
delta_a = timedelta(minutes = 2)
|
||||
dates_a = list(date_gen(delta = delta_a, n = 20))
|
||||
dates_a = list(date_gen(delta = delta_a, count = 20))
|
||||
dates_a.append("DONE")
|
||||
|
||||
events_a_args = zip(cycle(['a']), iter(dates_a), cycle([type]))
|
||||
@@ -133,7 +134,7 @@ class FeedGenTestCase(TestCase):
|
||||
|
||||
# Set up source 'b'. Outputs 10 events with 1 minute deltas.
|
||||
delta_b = timedelta(minutes = 1)
|
||||
dates_b = list(date_gen(delta = delta_b, n = 10))
|
||||
dates_b = list(date_gen(delta = delta_b, count = 10))
|
||||
dates_b.append("DONE")
|
||||
|
||||
events_b_args = zip(cycle(['b']), iter(dates_b), cycle([type]))
|
||||
@@ -152,13 +153,65 @@ class FeedGenTestCase(TestCase):
|
||||
|
||||
sequential = chain(iter(events_a), iter(events_b))
|
||||
self.run_FeedGen(sequential, expected, source_ids)
|
||||
|
||||
def test_full_feed_layer(self):
|
||||
|
||||
filter = [1,2]
|
||||
#Set up source a. One hour between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {'sids' : [1,2,3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(hours = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
#Set up source b. One day between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {'sids' : [1,2,3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(days = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
#Set up source c. One minute between events.
|
||||
args_c = tuple()
|
||||
kwargs_c = {'sids' : [1,2,3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
# Set up source d. This should produce no events because the
|
||||
# internal sids don't match the filter.
|
||||
args_d = tuple()
|
||||
kwargs_d = {'sids' : [3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
|
||||
sources = (SpecificEquityTrades,) * 4
|
||||
source_args = (args_a, args_b, args_c, args_d)
|
||||
source_kwargs = (kwargs_a, kwargs_b, kwargs_c, kwargs_d)
|
||||
|
||||
# Generate our expected source_ids.
|
||||
zip_args = zip(source_args, source_kwargs)
|
||||
expected_ids = ["SpecificEquityTrades" + hash_args(*args, **kwargs)
|
||||
for args, kwargs in zip_args]
|
||||
|
||||
# Pipe our sources into feed.
|
||||
feed_out = PreTransformLayer(sources, source_args, source_kwargs)
|
||||
|
||||
# Read all the values from feed and assert that they arrive in
|
||||
# the correct sorting with the expected hash values.
|
||||
to_list = list(feed_out)
|
||||
copy = to_list[:]
|
||||
for e in to_list:
|
||||
# All events should match one of our expected source_ids.
|
||||
assert e.source_id in expected_ids
|
||||
# But none of them should match source_d.
|
||||
assert e.source_id != hash_args(*args_d, **kwargs_d)
|
||||
|
||||
expected = sorted(copy, compare_by_dt_source_id)
|
||||
assert to_list == expected
|
||||
|
||||
def test_with_specific_equity(self):
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def mock_data_unframe(source_id, dt, type):
|
||||
event = ndict()
|
||||
event.source_id = source_id
|
||||
@@ -182,7 +235,3 @@ def compare_by_dt_source_id(x,y):
|
||||
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from itertools import izip, izip_longest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from zipline.gens.mongods import create_pymongo_iterator, MongoTradeHistoryGen
|
||||
from zipline.gens.utils import stringify_args, assert_datasource_protocol,\
|
||||
from zipline.gens.utils import hash_args, assert_datasource_protocol,\
|
||||
assert_trade_protocol, mock_raw_event
|
||||
|
||||
import zipline.protocol as zp
|
||||
@@ -107,7 +107,7 @@ class TestMongoDataGenerator(TestCase):
|
||||
for field in iter(['sid', 'dt', 'price', 'volume']):
|
||||
assert db[field] == expected[field]
|
||||
|
||||
# Expected output of stringify_args:
|
||||
# Expected output of hash_args:
|
||||
assert db['source_id'] == \
|
||||
'MongoTradeHistoryGen983a27fd0710414239a5cde71ef5a8fc'
|
||||
|
||||
|
||||
+108
-26
@@ -1,44 +1,126 @@
|
||||
import random
|
||||
from itertools import chain, repeat, cycle, ifilter
|
||||
from itertools import chain, repeat, cycle, ifilter, izip
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from zipline.utils.factory import create_trade, create_trade
|
||||
from zipline.gens.utils import date_gen
|
||||
from zipline.utils.factory import create_trade
|
||||
from zipline.gens.utils import hash_args, mock_done
|
||||
|
||||
def mock_prices(n, rand = False):
|
||||
"""Utility to generate a set of prices. By default
|
||||
cycles through values from 0.0 to 10.0 n times. Optional
|
||||
flag to give random values between 0.0 and 10.0"""
|
||||
def date_gen(start = datetime(2012, 6, 6, 0),
|
||||
delta = timedelta(minutes = 1),
|
||||
count = 100):
|
||||
"""
|
||||
Utility to generate a stream of dates.
|
||||
"""
|
||||
return (start + (i * delta) for i in xrange(count))
|
||||
|
||||
def mock_prices(count, rand = False):
|
||||
"""
|
||||
Utility to generate a stream of mock prices. By default
|
||||
cycles through values from 0.0 to 10.0, n times. Optional
|
||||
flag to give random values between 0.0 and 10.0
|
||||
"""
|
||||
|
||||
if rand:
|
||||
return (random.uniform(0.0, 10.0) for i in xrange(n))
|
||||
return (random.uniform(0.0, 10.0) for i in xrange(count))
|
||||
else:
|
||||
return (float(i % 11) for i in xrange(1,n+1))
|
||||
return (float(i % 11) for i in xrange(1,count+1))
|
||||
|
||||
def mock_volumes(n, rand = False):
|
||||
"""Does the same as mock_prices. Different function name
|
||||
for readability."""
|
||||
return mock_prices(n, rand)
|
||||
|
||||
def SpecificEquityTrades(n = 500, sids = [1, 2], event_list = None, filter = None):
|
||||
"""Returns the first n events of event_list if specified.
|
||||
Otherwise generates a sensible stream of events."""
|
||||
def mock_volumes(count, rand = False):
|
||||
"""
|
||||
Utility to generate a set of volumes. By default cycles
|
||||
through values from 100 to 1000, incrementing by 50. Optional
|
||||
flag to give random values between 100 and 1000.
|
||||
"""
|
||||
if rand:
|
||||
return (random.randrange(100, 1000) for i in xrange(count))
|
||||
else:
|
||||
return ((i * 50)%900 + 100 for i in xrange(count))
|
||||
|
||||
def fuzzy_dates(count = 500):
|
||||
"""
|
||||
Add +-10 seconds to each event from a date_gen. Note that this
|
||||
still guarantees sorting, since the default on date_gen is minute
|
||||
separation of events.
|
||||
"""
|
||||
for date in date_gen(count = count):
|
||||
yield date + timedelta(seconds = random.randint(-10, 10))
|
||||
|
||||
def SpecificEquityTrades(*args, **config):
|
||||
"""
|
||||
Yields all events in event_list that match the given sid_filter.
|
||||
If no event_list is specified, generates an internal stream of events
|
||||
to filter. Returns all events if filter is None.
|
||||
"""
|
||||
# We shouldn't get any positional arguments.
|
||||
assert args == ()
|
||||
|
||||
# Unpack config dictionary with default values.
|
||||
count = config.get('count', 500)
|
||||
sids = config.get('sids', [1, 2])
|
||||
start = config.get('start', datetime(2012, 6, 6, 0))
|
||||
delta = config.get('delta', timedelta(minutes = 1))
|
||||
|
||||
# Default to None for event_list and filter.
|
||||
event_list = config.get('event_list')
|
||||
filter = config.get('filter')
|
||||
|
||||
arg_string = hash_args(*args, **config)
|
||||
namestring = "SpecificEquityTrades" + arg_string
|
||||
# If we have an event_list, ignore the other arguments and use the list.
|
||||
# TODO: still append our namestring?
|
||||
if event_list:
|
||||
unfiltered = (event for event in event_list)
|
||||
|
||||
|
||||
# Set up iterators for each expected field.
|
||||
else:
|
||||
dates = date_gen(n = n)
|
||||
prices = mock_prices(n)
|
||||
volumes = mock_volumes(n)
|
||||
sids = cycle(iter(sids))
|
||||
|
||||
dates = date_gen(count = count, start = start, delta = delta)
|
||||
prices = mock_prices(count)
|
||||
volumes = mock_volumes(count)
|
||||
sids = cycle(sids)
|
||||
|
||||
# Combine the iterators into a single iterator of arguments
|
||||
arg_gen = izip(sids, prices, volumes, dates)
|
||||
|
||||
unfiltered = (create_trade(*args) for args in arg_gen)
|
||||
|
||||
# Convert argument packages into events.
|
||||
unfiltered = (create_trade(*args, source_id = namestring)
|
||||
for args in arg_gen)
|
||||
|
||||
# If we specified a sid filter, filter out elements that don't match the filter.
|
||||
if filter:
|
||||
filtered = ifilter(lambda event: event.sid in filter)
|
||||
filtered = ifilter(lambda event: event.sid in filter, unfiltered)
|
||||
|
||||
# Otherwise just use all events.
|
||||
else:
|
||||
filtered = unfiltered
|
||||
|
||||
# Add a done message to the end of the stream. For a live
|
||||
# datasource this would be handled by the containing Component.
|
||||
out = chain(filtered, [mock_done(namestring)])
|
||||
return out
|
||||
|
||||
def RandomEquityTrades(*args, **config):
|
||||
# We shouldn't get any positional args.
|
||||
assert args == ()
|
||||
|
||||
count = config.get('count', 500)
|
||||
sids = config.get('sids', [1,2])
|
||||
filter = config.get('filter')
|
||||
|
||||
dates = fuzzy_dates(count)
|
||||
prices = mock_prices(count, rand = True)
|
||||
volumes = mock_volumes(count, rand = True)
|
||||
sids = cycle(sids)
|
||||
|
||||
arg_gen = izip(sids, prices, volumes, dates)
|
||||
|
||||
unfiltered = (create_trade(*args) for args in arg_gen)
|
||||
|
||||
if filter:
|
||||
filtered = ifilter(lambda event: event.sid in filter, unfiltered)
|
||||
else:
|
||||
filtered = unfiltered
|
||||
return filtered
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# import nose.tools; nose.tools.set_trace()
|
||||
# trades = SpecificEquityTrades(filter = [1])
|
||||
|
||||
+45
-44
@@ -14,26 +14,24 @@ from numbers import Number
|
||||
from itertools import izip
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.utils import hash_args, date_gen
|
||||
from zipline.gens.utils import assert_feed_unframe_protocol, assert_transform_protocol
|
||||
from zipline.gens.tradegens import date_gen
|
||||
from zipline.gens.utils import assert_feed_unframe_protocol, \
|
||||
assert_transform_protocol, hash_args
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
def PassthroughTransformGen(stream_in):
|
||||
"""Trivial transform for event forwarding."""
|
||||
class Passthrough(object):
|
||||
"""
|
||||
Trivial function for forwarding events.
|
||||
"""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
# hash_args with no arguments is the same as:
|
||||
# hasher = hashlib.md5()
|
||||
# hasher.update(":");
|
||||
# hashlib.md5.digest().
|
||||
|
||||
namestring = "Passthrough" + hash_args()
|
||||
|
||||
for message in stream_in:
|
||||
assert_feed_unframe_protocol(message)
|
||||
out_value = message
|
||||
assert_transform_protocol(out_value)
|
||||
yield (namestring, out_value)
|
||||
def update(self, event):
|
||||
assert isinstance(event, ndict),"Bad event in Passthrough: %s" % event
|
||||
assert event.has_key('sid'), "No sid in Passthrough: %s" % event
|
||||
assert event.has_key('dt'), "No dt in Passthorughz: %s" % event
|
||||
return event
|
||||
|
||||
def FunctionalTransformGen(stream_in, fun, *args, **kwargs):
|
||||
"""
|
||||
@@ -43,6 +41,10 @@ def FunctionalTransformGen(stream_in, fun, *args, **kwargs):
|
||||
"""
|
||||
|
||||
# TODO: Distinguish between functions and classes in hash_args.
|
||||
# As implemented we will get assertion errors if a function and
|
||||
# stateful class have the same name, which may or may not be
|
||||
# what we want.
|
||||
|
||||
namestring = fun.__name__ + hash_args(*args, **kwargs)
|
||||
|
||||
for message in stream_in:
|
||||
@@ -70,13 +72,6 @@ def StatefulTransformGen(stream_in, tnfm_class, *args, **kwargs):
|
||||
assert_transform_protocol(out_value)
|
||||
yield (namestring, out_value)
|
||||
|
||||
def MovingAverageTransformGen(stream_in, days, fields):
|
||||
"""
|
||||
Generator that uses the MovingAverage state class to calculate
|
||||
a moving average for all stocks over a specified number of days.
|
||||
"""
|
||||
return StatefulTransformGen(stream_in, MovingAverage, timedelta(days=days), fields)
|
||||
|
||||
class MovingAverage(object):
|
||||
"""
|
||||
Class that maintains a dictionary from sids to EventWindows
|
||||
@@ -91,7 +86,7 @@ class MovingAverage(object):
|
||||
# No way to pass arguments to the defaultdict factory, so we
|
||||
# need to define a method to generate the correct EventWindows.
|
||||
self.sid_windows = defaultdict(self.create_window)
|
||||
|
||||
|
||||
def create_window(self):
|
||||
"""Factory method for self.sid_windows."""
|
||||
return EventWindow(self.delta, self.fields)
|
||||
@@ -104,13 +99,18 @@ class MovingAverage(object):
|
||||
|
||||
assert isinstance(event, ndict),"Bad event in MovingAverage: %s" % event
|
||||
assert event.has_key('sid'), "No sid in MovingAverage: %s" % event
|
||||
assert event.has_key('dt'), "No dt in MovingAverage: %s" % event
|
||||
|
||||
output = ndict({'sid': event.sid, 'dt': event.dt})
|
||||
# This will create a new EventWindow if this is the first
|
||||
# message for this sid.
|
||||
window = self.sid_windows[event.sid]
|
||||
window.update(event)
|
||||
averages = window.get_averages()
|
||||
|
||||
return window.get_averages()
|
||||
# Return the calculated averages along with
|
||||
output.merge(averages)
|
||||
return output
|
||||
|
||||
class EventWindow(object):
|
||||
"""
|
||||
@@ -144,7 +144,7 @@ class EventWindow(object):
|
||||
# newest oldest
|
||||
# | |
|
||||
# V V
|
||||
|
||||
|
||||
while (self.ticks[-1].dt - self.ticks[0].dt) >= self.delta:
|
||||
# popleft removes and returns ticks[0]
|
||||
popped = self.ticks.popleft()
|
||||
@@ -168,6 +168,7 @@ class EventWindow(object):
|
||||
Return an ndict of all our tracked averages.
|
||||
"""
|
||||
out = ndict()
|
||||
# out.ticks = len(self.ticks)
|
||||
for field in self.fields:
|
||||
out[field] = self.average(field)
|
||||
return out
|
||||
@@ -186,26 +187,26 @@ class EventWindow(object):
|
||||
assert isinstance(event[field], Number), \
|
||||
"Got %s for %s in EventWindow" % (event[field], field)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# if __name__ == "__main__":
|
||||
|
||||
def make_event(**kwargs):
|
||||
e = ndict()
|
||||
for key, value in kwargs.iteritems():
|
||||
e[key] = value
|
||||
return e
|
||||
# def make_event(**kwargs):
|
||||
# e = ndict()
|
||||
# for key, value in kwargs.iteritems():
|
||||
# e[key] = value
|
||||
# return e
|
||||
|
||||
dates = date_gen(delta = timedelta(hours = 12))
|
||||
events = (
|
||||
make_event(
|
||||
sid = 'foo', price = random.random(),
|
||||
dt = date,
|
||||
type = zp.DATASOURCE_TYPE.TRADE,
|
||||
source_id = 'ds',
|
||||
vol = i
|
||||
)
|
||||
for date, i in izip(dates, xrange(100))
|
||||
)
|
||||
# dates = date_gen(delta = timedelta(hours = 12))
|
||||
# events = (
|
||||
# make_event(
|
||||
# sid = 'foo', price = random.random(),
|
||||
# dt = date,
|
||||
# type = zp.DATASOURCE_TYPE.TRADE,
|
||||
# source_id = 'ds',
|
||||
# vol = i
|
||||
# )
|
||||
# for date, i in izip(dates, xrange(100))
|
||||
# )
|
||||
|
||||
gen = MovingAverageTransformGen(events, 1, ['price', 'vol'])
|
||||
# gen = MovingAverageTransformGen(events, 1, ['price', 'vol'])
|
||||
|
||||
|
||||
|
||||
+23
-6
@@ -16,16 +16,28 @@ def mock_raw_event(sid, dt):
|
||||
}
|
||||
return event
|
||||
|
||||
def date_gen(start = datetime(2012, 6, 6, 0), delta = timedelta(minutes = 1), n = 100):
|
||||
return (start + i * delta for i in xrange(n))
|
||||
|
||||
def mock_done(source_id):
|
||||
return ndict({'dt': "DONE", "source_id" : source_id, 'type' : 0})
|
||||
|
||||
def alternate(g1, g2):
|
||||
"""Specialized version of roundrobin for just 2 generators."""
|
||||
for e1, e2 in izip_longest(g1, g2):
|
||||
if e1 != None:
|
||||
yield e1
|
||||
if e2 != None:
|
||||
yield e2
|
||||
|
||||
def roundrobin(*args):
|
||||
"""
|
||||
Takes N generators, pulling one element off each until all inputs
|
||||
are empty.
|
||||
"""
|
||||
for elem_tuple in izip_longest(*args):
|
||||
for value in elem_tuple:
|
||||
if value != None:
|
||||
yield value
|
||||
|
||||
|
||||
def hash_args(*args, **kwargs):
|
||||
"""Define a unique string for any set of representable args."""
|
||||
arg_string = '_'.join([str(arg) for arg in args])
|
||||
@@ -73,14 +85,19 @@ def assert_feed_protocol(event):
|
||||
assert event.type in DATASOURCE_TYPE
|
||||
assert event.has_key('dt')
|
||||
|
||||
|
||||
def assert_feed_unframe_protocol(event):
|
||||
"""Same as above."""
|
||||
assert isinstance(event, ndict)
|
||||
assert isinstance(event.source_id, basestring)
|
||||
assert event.type in DATASOURCE_TYPE
|
||||
assert event.has_key('dt')
|
||||
|
||||
|
||||
def assert_transform_protocol(event):
|
||||
pass
|
||||
"""Transforms should return an ndict to be merged by MergeGen."""
|
||||
assert isinstance(event, ndict)
|
||||
|
||||
def assert_merge_protocol(tnfm_ids, message):
|
||||
"""Merge should output an ndict with a field for each id in its transform set."""
|
||||
assert isinstance(message, ndict)
|
||||
assert set(tnfm_ids) == set(message.keys())
|
||||
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
import zmq
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
def gen_from_zmq(poller, unframe):
|
||||
"""
|
||||
A generator that takes an initialized zmq poller and yields
|
||||
messages from the poller until it gets a zp.CONTROL_PROTOCOL.DONE.
|
||||
"""
|
||||
while True:
|
||||
message = poller.recv()
|
||||
if message = zp.CONTROL_PROTOCOL.DONE:
|
||||
yield "DONE"
|
||||
break
|
||||
else:
|
||||
yield unframe(message)
|
||||
@@ -69,9 +69,9 @@ def create_trading_environment(year=2006):
|
||||
|
||||
return trading_environment
|
||||
|
||||
def create_trade(sid, price, amount, datetime):
|
||||
def create_trade(sid, price, amount, datetime, source_id = "test_factory"):
|
||||
row = zp.ndict({
|
||||
'source_id' : "test_factory",
|
||||
'source_id' : source_id,
|
||||
'type' : zp.DATASOURCE_TYPE.TRADE,
|
||||
'sid' : sid,
|
||||
'dt' : datetime,
|
||||
|
||||
Reference in New Issue
Block a user