api changes and refactor of sorting tests

This commit is contained in:
scottsanderson
2012-08-06 11:05:25 -04:00
parent b67cbb2aab
commit 4655e643a4
8 changed files with 397 additions and 278 deletions
-233
View File
@@ -1,233 +0,0 @@
from unittest2 import TestCase
from itertools import cycle, chain
from datetime import datetime, timedelta
from collections import deque
from zipline import ndict
from zipline.gens.sort import \
date_sort, \
ready, \
done, \
queue_is_ready,\
queue_is_done
from zipline.gens.utils import hash_args, alternate
from zipline.gens.tradegens import date_gen, SpecificEquityTrades
from zipline.gens.composites import date_sorted_sources
import zipline.protocol as zp
class HelperTestCase(TestCase):
    """
    Tests for the queue-state helpers imported from zipline.gens.sort:
    queue_is_ready/queue_is_done for a single queue, and ready/done
    for a dict of queues keyed by source id.
    """

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_individual_queue_logic(self):
        """Walk one deque through the empty, ready and done states."""
        queue = deque()

        # Empty queues are neither done nor ready.
        self.assertFalse(queue_is_ready(queue))
        self.assertFalse(queue_is_done(queue))

        queue.append(to_dt('foo'))
        self.assertTrue(queue_is_ready(queue))
        self.assertFalse(queue_is_done(queue))

        queue.appendleft(to_dt('DONE'))
        self.assertTrue(queue_is_ready(queue))

        # Checking done when we have a message after done will trip an
        # assert.
        self.assertRaises(AssertionError, queue_is_done, queue)

        # Drop the trailing message so only the DONE marker remains.
        queue.pop()
        self.assertTrue(queue_is_ready(queue))
        self.assertTrue(queue_is_done(queue))

    def test_pop_logic(self):
        """Walk a dict of three queues through the same states."""
        ids = ['a', 'b', 'c']
        sources = {}
        for source_id in ids:
            sources[source_id] = deque()

        self.assertFalse(ready(sources))
        self.assertFalse(done(sources))

        # All sources must have a message to be ready/done.
        sources['a'].append(to_dt("datetime"))
        self.assertFalse(ready(sources))
        self.assertFalse(done(sources))
        sources['a'].pop()

        for source_id in ids:
            sources[source_id].append(to_dt("datetime"))

        self.assertTrue(ready(sources))
        self.assertFalse(done(sources))

        for source_id in ids:
            sources[source_id].appendleft(to_dt("DONE"))

        # ["DONE", message] will trip an assert in queue_is_done.
        self.assertTrue(ready(sources))
        self.assertRaises(AssertionError, done, sources)

        # Drop the trailing message from every queue, leaving only DONE.
        for source_id in ids:
            sources[source_id].pop()

        self.assertTrue(ready(sources))
        self.assertTrue(done(sources))
class DateSortTestCase(TestCase):
    """Tests for date_sort and the date_sorted_sources composite."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def _make_events(self, source_id, event_type, **date_gen_kwargs):
        """
        Build the event list for one mock source: one event per date
        yielded by date_gen(**date_gen_kwargs), followed by a final
        event whose dt is the string "DONE".
        """
        dates = list(date_gen(**date_gen_kwargs))
        dates.append("DONE")
        return [mock_data_unframe(source_id, dt, event_type) for dt in dates]

    def run_date_sort(self, events, expected, source_ids):
        """
        Take a list of events, their source_ids, and an expected sorting.
        Assert that date_sort's output agrees with expected.
        """
        sorted_events = list(date_sort(events, source_ids))
        self.assertEqual(sorted_events, expected)

    def test_single_source(self):
        """A single source should come out unchanged, minus its DONE."""
        source_ids = ['a']
        trade_type = zp.DATASOURCE_TYPE.TRADE

        # 100 events, increasing by a minute at a time.
        # [('a', date1, type), ('a', date2, type), ... ('a', "DONE", type)]
        events = self._make_events('a', trade_type, count=100)

        # We don't expect date_sort to yield the last (DONE) event.
        expected = events[:-1]

        self.run_date_sort(iter(events), expected, source_ids)

    def test_multi_source(self):
        """Two sources must be merged by dt, then by source_id."""
        source_ids = ['a', 'b']
        trade_type = zp.DATASOURCE_TYPE.TRADE

        # Set up source 'a'. Outputs 20 events with 2 minute deltas.
        events_a = self._make_events(
            'a', trade_type, delta=timedelta(minutes=2), count=20)

        # Set up source 'b'. Outputs 10 events with 1 minute deltas.
        events_b = self._make_events(
            'b', trade_type, delta=timedelta(minutes=1), count=10)

        # The expected output is all non-DONE events in both a and b,
        # sorted first by dt and then by source_id.
        non_dones = events_a[:-1] + events_b[:-1]
        expected = sorted(non_dones, compare_by_dt_source_id)

        # Alternating between a and b.
        interleaved = alternate(iter(events_a), iter(events_b))
        self.run_date_sort(interleaved, expected, source_ids)

        # All of a, then all of b.
        sequential = chain(iter(events_a), iter(events_b))
        self.run_date_sort(sequential, expected, source_ids)

    def test_sorted_sources(self):
        """Check ordering and source ids coming out of the composite."""
        sid_filter = [1, 2]

        # Set up source a. One hour between events.
        args_a = tuple()
        kwargs_a = {'sids': [1, 2, 3, 4],
                    'start': datetime(2012, 6, 6, 0),
                    'delta': timedelta(hours=1),
                    'filter': sid_filter}

        # Set up source b. One day between events.
        args_b = tuple()
        kwargs_b = {'sids': [1, 2, 3, 4],
                    'start': datetime(2012, 6, 6, 0),
                    'delta': timedelta(days=1),
                    'filter': sid_filter}

        # Set up source c. One minute between events.
        args_c = tuple()
        kwargs_c = {'sids': [1, 2, 3, 4],
                    'start': datetime(2012, 6, 6, 0),
                    'delta': timedelta(minutes=1),
                    'filter': sid_filter}

        # Set up source d. This should produce no events because the
        # internal sids don't match the filter.
        args_d = tuple()
        kwargs_d = {'sids': [3, 4],
                    'start': datetime(2012, 6, 6, 0),
                    'delta': timedelta(minutes=1),
                    'filter': sid_filter}

        sources = (SpecificEquityTrades,) * 4
        source_args = (args_a, args_b, args_c, args_d)
        source_kwargs = (kwargs_a, kwargs_b, kwargs_c, kwargs_d)

        # Generate our expected source_ids.
        expected_ids = ["SpecificEquityTrades" + hash_args(*args, **kwargs)
                        for args, kwargs in zip(source_args, source_kwargs)]

        # Source d's full id. The comparison has to use the prefixed
        # form: every id checked below carries the "SpecificEquityTrades"
        # prefix, so comparing against the bare hash could never fail.
        source_d_id = expected_ids[-1]

        # Pipe our sources into sort.
        sort_out = date_sorted_sources(sources, source_args, source_kwargs)

        # Read all the values from sort and assert that they arrive in
        # the correct sorting with the expected hash values.
        to_list = list(sort_out)

        for event in to_list:
            # All events should match one of our expected source_ids.
            self.assertIn(event.source_id, expected_ids)
            # But none of them should match source_d.
            self.assertNotEqual(event.source_id, source_d_id)

        # sorted() returns a new list, so to_list itself is untouched.
        expected = sorted(to_list, compare_by_dt_source_id)
        self.assertEqual(to_list, expected)
def mock_data_unframe(source_id, dt, type):
    """
    Build a mock event: an ndict carrying the given source_id, dt and
    type as attributes.
    """
    mock = ndict()
    mock.source_id, mock.dt, mock.type = source_id, dt, type
    return mock
def to_dt(val):
    """Wrap val in an ndict whose only field, 'dt', holds val."""
    wrapped = ndict({'dt': val})
    return wrapped
def compare_by_dt_source_id(x, y):
    """
    cmp-style comparator: order events by dt, using source_id as the
    tie-breaker. Returns -1, 1 or 0.
    """
    # Later fields are only consulted when every earlier field is
    # neither less than nor greater than its counterpart.
    for field in ('dt', 'source_id'):
        left = getattr(x, field)
        right = getattr(y, field)
        if left < right:
            return -1
        if left > right:
            return 1
    return 0
+257
View File
@@ -0,0 +1,257 @@
import pytz
from unittest2 import TestCase
from itertools import cycle, chain, izip, izip_longest
from datetime import datetime, timedelta
from collections import deque
from zipline import ndict
from zipline.gens.sort import \
date_sort, \
ready, \
done, \
queue_is_ready,\
queue_is_done
from zipline.gens.utils import hash_args, alternate, done_message
from zipline.gens.tradegens import date_gen, SpecificEquityTrades
from zipline.gens.composites import date_sorted_sources
import zipline.protocol as zp
class HelperTestCase(TestCase):
    """
    Tests for the queue-state helpers imported from zipline.gens.sort:
    queue_is_ready/queue_is_done for a single queue, and ready/done
    for a dict of queues keyed by source id.
    """

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_individual_queue_logic(self):
        """Walk one deque through the empty, ready and done states."""
        queue = deque()

        # Empty queues are neither done nor ready.
        self.assertFalse(queue_is_ready(queue))
        self.assertFalse(queue_is_done(queue))

        queue.append(to_dt('foo'))
        self.assertTrue(queue_is_ready(queue))
        self.assertFalse(queue_is_done(queue))

        queue.appendleft(to_dt('DONE'))
        self.assertTrue(queue_is_ready(queue))

        # Checking done when we have a message after done will trip an
        # assert.
        self.assertRaises(AssertionError, queue_is_done, queue)

        # Drop the trailing message so only the DONE marker remains.
        queue.pop()
        self.assertTrue(queue_is_ready(queue))
        self.assertTrue(queue_is_done(queue))

    def test_pop_logic(self):
        """Walk a dict of three queues through the same states."""
        ids = ['a', 'b', 'c']
        sources = {}
        for source_id in ids:
            sources[source_id] = deque()

        self.assertFalse(ready(sources))
        self.assertFalse(done(sources))

        # All sources must have a message to be ready/done.
        sources['a'].append(to_dt("datetime"))
        self.assertFalse(ready(sources))
        self.assertFalse(done(sources))
        sources['a'].pop()

        for source_id in ids:
            sources[source_id].append(to_dt("datetime"))

        self.assertTrue(ready(sources))
        self.assertFalse(done(sources))

        for source_id in ids:
            sources[source_id].appendleft(to_dt("DONE"))

        # ["DONE", message] will trip an assert in queue_is_done.
        self.assertTrue(ready(sources))
        self.assertRaises(AssertionError, done, sources)

        # Drop the trailing message from every queue, leaving only DONE.
        for source_id in ids:
            sources[source_id].pop()

        self.assertTrue(ready(sources))
        self.assertTrue(done(sources))
class DateSortTestCase(TestCase):
    """Tests for date_sort and the date_sorted_sources composite."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def _with_done(self, source):
        """
        Return source's events followed by a done message for it. The
        raw source doesn't handle done messaging, so date_sort needs
        this appended to work properly.
        """
        return chain(source, [done_message(source.get_hash())])

    def run_date_sort(self, event_stream, expected, source_ids):
        """
        Take a list of events, their source_ids, and an expected sorting.
        Assert that date_sort's output agrees with expected.
        """
        sort_out = date_sort(event_stream, source_ids)
        # izip_longest pads the shorter side with None, so a length
        # mismatch shows up as a failed comparison instead of being
        # silently truncated.
        for actual, wanted in izip_longest(sort_out, expected):
            self.assertEqual(actual, wanted)

    def test_single_source(self):
        """A single source should pass through date_sort unchanged."""
        # Just using the built-in defaults. See
        # zipline/gens/tradegens.py
        source = SpecificEquityTrades()
        expected = list(source)

        # Building expected consumed the source; rewind it for the run.
        source.rewind()

        self.run_date_sort(
            self._with_done(source),
            expected,
            [source.get_hash()]
        )

    def test_multi_source(self):
        """Two sources must be merged by dt, then by source_id."""
        sid_filter = [2, 3]

        args_a = tuple()
        kwargs_a = {
            'count': 100,
            'sids': [1, 2, 3],
            'start': datetime(2012, 1, 3, 15, tzinfo=pytz.utc),
            'delta': timedelta(minutes=6),
            'filter': sid_filter
        }
        source_a = SpecificEquityTrades(*args_a, **kwargs_a)

        args_b = tuple()
        kwargs_b = {
            'count': 100,
            'sids': [2, 3, 4],
            'start': datetime(2012, 1, 3, 15, tzinfo=pytz.utc),
            'delta': timedelta(minutes=5),
            'filter': sid_filter
        }
        source_b = SpecificEquityTrades(*args_b, **kwargs_b)

        all_events = list(chain(source_a, source_b))

        # The expected output is all events, sorted by dt with
        # source_id as a tiebreaker.
        expected = sorted(all_events, comp)
        source_ids = [source_a.get_hash(), source_b.get_hash()]

        # Generating the events list consumes the sources. Rewind them
        # for testing.
        source_a.rewind()
        source_b.rewind()

        # Test sort with alternating messages from source_a and
        # source_b, each followed by its done message.
        interleaved = alternate(
            self._with_done(source_a),
            self._with_done(source_b)
        )
        self.run_date_sort(interleaved, expected, source_ids)

        source_a.rewind()
        source_b.rewind()

        # Test sort with all messages from a, followed by all messages
        # from b.
        sequential = chain(
            self._with_done(source_a),
            self._with_done(source_b)
        )
        self.run_date_sort(sequential, expected, source_ids)

    def test_sort_composite(self):
        """Check ordering, count and ids coming out of the composite."""
        sid_filter = [1, 2]

        # Set up source a. One hour between events.
        args_a = tuple()
        kwargs_a = {
            'count': 100,
            'sids': [1],
            'start': datetime(2012, 6, 6, 0),
            'delta': timedelta(hours=1),
            'filter': sid_filter
        }
        source_a = SpecificEquityTrades(*args_a, **kwargs_a)

        # Set up source b. One day between events.
        args_b = tuple()
        kwargs_b = {
            'count': 50,
            'sids': [2],
            'start': datetime(2012, 6, 6, 0),
            'delta': timedelta(days=1),
            'filter': sid_filter
        }
        source_b = SpecificEquityTrades(*args_b, **kwargs_b)

        # Set up source c. One minute between events.
        args_c = tuple()
        kwargs_c = {
            'count': 150,
            'sids': [1, 2],
            'start': datetime(2012, 6, 6, 0),
            'delta': timedelta(minutes=1),
            'filter': sid_filter
        }
        source_c = SpecificEquityTrades(*args_c, **kwargs_c)

        # Set up source d. This should produce no events because the
        # internal sids don't match the filter.
        args_d = tuple()
        kwargs_d = {
            'count': 50,
            'sids': [3],
            'start': datetime(2012, 6, 6, 0),
            'delta': timedelta(minutes=1),
            'filter': sid_filter
        }
        source_d = SpecificEquityTrades(*args_d, **kwargs_d)

        sources = [source_a, source_b, source_c, source_d]
        hashes = [source.get_hash() for source in sources]
        # hashes follows the order of sources, so source d's is last.
        source_d_hash = hashes[-1]

        sort_out = date_sorted_sources(*sources)

        # Read all the values from sort and assert that they arrive in
        # the correct sorting with the expected hash values.
        to_list = list(sort_out)

        # We should have 300 events (100 from a, 50 from b, 150 from c).
        self.assertEqual(len(to_list), 300)

        for event in to_list:
            # All events should match one of our expected source_ids.
            self.assertIn(event.source_id, hashes)
            # But none of them should match source_d.
            self.assertNotEqual(event.source_id, source_d_hash)

        # The events should be sorted by dt, with source_id as
        # tiebreaker. sorted() returns a new list, so to_list itself
        # is untouched.
        expected = sorted(to_list, comp)
        self.assertEqual(to_list, expected)
def compare_by_dt_source_id(x, y):
    """
    cmp-style comparator: order events by dt, using source_id as the
    tie-breaker. Returns -1, 1 or 0.
    """
    # Later fields are only consulted when every earlier field is
    # neither less than nor greater than its counterpart.
    for field in ('dt', 'source_id'):
        left = getattr(x, field)
        right = getattr(y, field)
        if left < right:
            return -1
        if left > right:
            return 1
    return 0


# Short alias for the comparator, for ease of use in the tests.
comp = compare_by_dt_source_id
+36 -15
View File
@@ -13,8 +13,9 @@ TransformBundle = namedtuple("TransformBundle", ['tnfm', 'args', 'kwargs'])
def date_sorted_sources(*sources):
"""
Takes an iterable of SortBundles, generating namestrings and initialized datasources
for each before piping them into a date_sort.
Takes an iterable of SortBundles, generating namestrings and
initialized datasources for each before piping them into a
date_sort.
"""
for source in sources:
@@ -28,21 +29,21 @@ def date_sorted_sources(*sources):
# one element at a time from each.
stream_in = roundrobin(sources, names)
# Guarantee the flat stream will be sorted by date, using source_id as
# tie-breaker, which is fully deterministic (given deterministic string
# representation for all args/kwargs)
# Guarantee the flat stream will be sorted by date, using
# source_id as tie-breaker, which is fully deterministic (given
# deterministic string representation for all args/kwargs)
return date_sort(stream_in, names)
def merged_transforms(sorted_stream, *transforms):
"""
A generator that takes the expected output of a date_sort, pipes it
through a given set of transforms, and runs the results throught a
merge to output a unified stream. tnfms should be a list of
pointers to generator functions. tnfm_args should be a list of
tuples, representing the arguments to be passed to each transform.
tnfm_kwargs should be a list of dictionaries representing keyword
arguments to each transform.
A generator that takes the expected output of a date_sort, pipes
it through a given set of transforms, and runs the results
through a merge to output a unified stream. tnfms should be a
list of pointers to generator functions. tnfm_args should be a
list of tuples, representing the arguments to be passed to each
transform. tnfm_kwargs should be a list of dictionaries
representing keyword arguments to each transform.
"""
for transform in transforms:
assert isinstance(transform, StatefulTransform)
@@ -62,15 +63,35 @@ def merged_transforms(sorted_stream, *transforms):
# Roundrobin the outputs of our transforms to create a single flat
# stream.
to_merge = roundrobin(tnfm_gens, namestrings)
# Pipe the stream into merge.
merged = merge(to_merge, namestrings)
# Return the merged events.
return merged
def zipline(sources, transforms, endpoint):
assert isinstance(sources, (list, tuple))
def sequential_transforms(stream_in, *transforms):
    """
    Chain the given transforms over stream_in, in the order given.

    Every transform is first switched to append-only behavior
    (forward_all and update_in_place cleared, append_value set). The
    stream is then threaded through each transform's transform()
    method in turn, so the output of one transform is the input of
    the next.

    Returns the output of the last transform, or stream_in itself
    when no transforms are given.
    """
    # transforms is collected through *args, so it is always a tuple
    # and needs no type check.
    for tnfm in transforms:
        tnfm.forward_all = False
        tnfm.update_in_place = False
        tnfm.append_value = True

    # Thread the stream through each transform in order. A plain loop
    # avoids the builtin reduce, which only exists on Python 2.
    stream_out = stream_in
    for tnfm in transforms:
        stream_out = tnfm.transform(stream_out)

    return stream_out
+55 -15
View File
@@ -1,14 +1,16 @@
import pytz
import time
from time import sleep
from pprint import pprint as pp
from datetime import datetime, timedelta
from itertools import izip
from zipline.utils.factory import create_trading_environment
from zipline.test_algorithms import TestAlgorithm
from zipline.gens.composites import SourceBundle, TransformBundle, \
date_sorted_sources, merged_transforms
date_sorted_sources, merged_transforms, sequential_transforms
from zipline.gens.tradegens import SpecificEquityTrades
from zipline.gens.transform import MovingAverage, Passthrough, StatefulTransform
from zipline.gens.tradesimulation import TradeSimulationClient as tsc
@@ -18,43 +20,81 @@ import zipline.protocol as zp
if __name__ == "__main__":
filter = [2,3]
#Set up source a. One minute between events.
#Set up source a. Six minutes between events.
args_a = tuple()
kwargs_a = {
'count' : 325,
'count' : 1000,
'sids' : [1,2,3],
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
'delta' : timedelta(hours = 6),
'delta' : timedelta(minutes = 6),
'filter' : filter
}
source_a = SpecificEquityTrades(*args_a, **kwargs_a)
source_a_prime = SpecificEquityTrades(*args_a, **kwargs_a)
#Set up source b. Two minutes between events.
#Set up source b. Five minutes between events.
args_b = tuple()
kwargs_b = {
'count' : 7500,
'count' : 1000,
'sids' : [2,3,4],
'start' : datetime(2012,1,3,14, tzinfo = pytz.utc),
'delta' : timedelta(minutes = 5),
'filter' : filter
}
source_b = SpecificEquityTrades(*args_b, **kwargs_b)
#Set up source c. Three minutes between events.
source_b_prime = SpecificEquityTrades(*args_b, **kwargs_b)
sorted = date_sorted_sources(source_a, source_b)
sorted_prime = date_sorted_sources(
source_a_prime,
source_b_prime
)
passthrough = StatefulTransform(Passthrough)
mavg_price = StatefulTransform(MovingAverage, timedelta(minutes = 20), ['price'])
mavg_price = StatefulTransform(
MovingAverage,
timedelta(minutes = 20),
['price']
)
passthrough_prime = StatefulTransform(Passthrough)
mavg_price_prime = StatefulTransform(
MovingAverage,
timedelta(minutes = 20),
['price']
)
merged = merged_transforms(sorted, passthrough, mavg_price)
start = time.time()
for message in merged:
assert 1 + 1 == 2
stop = time.time()
merge_time = stop - start
print "Merge time: %s" % str(merge_time)
sequential = sequential_transforms(
sorted_prime,
passthrough_prime,
mavg_price_prime
)
algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3])
environment = create_trading_environment(year = 2012)
style = zp.SIMULATION_STYLE.FIXED_SLIPPAGE
start = time.time()
for message in sequential:
assert 1 + 1 == 2
stop = time.time()
seq_time = stop - start
print "Sequential time: %s" % str(seq_time)
print "Merge/Seq: %s" % (str(merge_time/seq_time))
trading_client = tsc(algo, environment, style)
# merged = merged_transforms(sorted, passthrough, mavg_price)
for message in trading_client.simulate(merged):
pp(message)
# algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3])
# environment = create_trading_environment(year = 2012)
# style = zp.SIMULATION_STYLE.FIXED_SLIPPAGE
# trading_client = tsc(algo, environment, style)
# for message in trading_client.simulate(merged):
# pp(message)
+3 -4
View File
@@ -11,8 +11,8 @@ from itertools import repeat
def merge(stream_in, tnfm_ids):
"""
A generator that takes a generator and a list of source_ids. We
maintain an internal queue for each id in source_ids. Once we
A generator that takes a generator and a list of transform ids. We
maintain an internal queue for each id in tnfm_ids. Once we
have a message from every queue, we pop an event from each queue
and merge them together into an event. We raise an error if we
do not receive the same number of events from all sources.
@@ -54,9 +54,8 @@ def merge(stream_in, tnfm_ids):
yield done_message('Merge')
def merge_one(sources):
dict_primer = zip(sources.keys(), repeat(None))
event_fields = ndict()
event_fields = ndict()
for key, queue in sources.iteritems():
# Add transform value to the transforms dict.
+1 -2
View File
@@ -14,7 +14,6 @@ def date_sort(stream_in, source_ids):
have messages pending from all sources, we pull the earliest
message and yield it.
"""
assert isinstance(source_ids, (list, tuple))
# Set up an internal queue for each expected source.
@@ -41,7 +40,7 @@ def date_sort(stream_in, source_ids):
message = pop_oldest(sources)
assert_sort_protocol(message)
yield message
# We should have only a done message left in each queue.
for queue in sources.itervalues():
assert len(queue) == 1, "Bad queue in date_sort on exit: %s" % queue
+37 -9
View File
@@ -56,9 +56,12 @@ class StatefulTransform(object):
self.forward_all = tnfm_class.__dict__.get('FORWARDER', False)
self.update_in_place = tnfm_class.__dict__.get('UPDATER', False)
self.append_value = tnfm_class.__dict__.get('APPENDER', False)
# You can't be both a forwarded and an updater.
assert not all([self.forward_all, self.update_in_place])
# Only one special behavior mode can be set.
assert sum(map(int, [self.forward_all,
self.update_in_place,
self.append_value])) <= 1
# Create an instance of our transform class.
self.state = tnfm_class(*args, **kwargs)
@@ -75,11 +78,15 @@ class StatefulTransform(object):
def _gen(self, stream_in):
# IMPORTANT: Messages may contain pointers that are shared with
# other streams, so we only manipulate copies.
for message in stream_in:
# allow upstream generators to yield None to avoid
# blocking.
if message == None:
continue
#TODO: refactor this to avoid unnecessary copying.
assert_sort_unframe_protocol(message)
message_copy = deepcopy(message)
@@ -87,22 +94,43 @@ class StatefulTransform(object):
# Same shared pointer issue here as above.
tnfm_value = self.state.update(deepcopy(message_copy))
# If we want to keep all original values, plus append tnfm_id
# and tnfm_value. Used for Passthrough.
# FORWARDER flag means we want to keep all original
# values, plus append tnfm_id and tnfm_value. Used for
# preserving the original event fields when our output
# will be fed into a merge.
if self.forward_all:
out_message = message_copy
out_message.tnfm_id = self.namestring
out_message.tnfm_value = tnfm_value
yield out_message
# Our expectation is that the transform simply updated the
# message it was passed. Useful for chaining together
# multiple transforms, e.g. TransactionSimulator/PerformanceTracker.
# UPDATER flag should be used for transforms that
# side-effectfully modify the event they are passed.
# Updated messages are passed along exactly as they are
# returned to us by our state class. Useful for chaining
# specific transforms that won't be fed to a merge. (See
# the implementation of TradeSimulationClient for example
# usage of this flag with PerformanceTracker and
# TransactionSimulator.)
elif self.update_in_place:
yield tnfm_value
# APPENDER flag should be used to add a single new
# key-value pair to the event. The new key is this
# transform's namestring, and its value is the value
# returned by state.update(event). This is almost
# identical to the behavior of FORWARDER, except we
# compress the two calculated values (tnfm_id, and
# tnfm_value) into a single field.
elif self.append_value:
out_message = message_copy
out_message[self.namestring] = tnfm_value
yield out_message
# Otherwise send tnfm_id, tnfm_value, and the message
# date. Useful for transforms being piped to a merge.
# If no flags are set, we create a new message containing
# just the tnfm_id, the event's datetime, and the
# calculated tnfm_value. This is the default behavior for
# a transform being fed into a merge.
else:
out_message = ndict()
out_message.tnfm_id = self.namestring
+8
View File
@@ -66,6 +66,14 @@ def hash_args(*args, **kwargs):
hasher.update(combined)
return hasher.hexdigest()
def sum_true(bool_iterable):
    """
    Count the True values in an iterable of booleans.

    Each value is converted with int() and the results are summed.
    """
    return sum(int(flag) for flag in bool_iterable)
def assert_datasource_protocol(event):
"""Assert that an event meets the protocol for datasource outputs."""