mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-03 12:34:52 +08:00
api changes and refactor of sorting tests
This commit is contained in:
@@ -1,233 +0,0 @@
|
||||
from unittest2 import TestCase
|
||||
from itertools import cycle, chain
|
||||
from datetime import datetime, timedelta
|
||||
from collections import deque
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.sort import \
|
||||
date_sort, \
|
||||
ready, \
|
||||
done, \
|
||||
queue_is_ready,\
|
||||
queue_is_done
|
||||
from zipline.gens.utils import hash_args, alternate
|
||||
from zipline.gens.tradegens import date_gen, SpecificEquityTrades
|
||||
from zipline.gens.composites import date_sorted_sources
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
class HelperTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def test_individual_queue_logic(self):
|
||||
queue = deque()
|
||||
# Empty queues are neither done nor ready.
|
||||
assert not queue_is_ready(queue)
|
||||
assert not queue_is_done(queue)
|
||||
|
||||
queue.append(to_dt('foo'))
|
||||
assert queue_is_ready(queue)
|
||||
assert not queue_is_done(queue)
|
||||
|
||||
|
||||
queue.appendleft(to_dt('DONE'))
|
||||
assert queue_is_ready(queue)
|
||||
|
||||
# Checking done when we have a message after done will trip an assert.
|
||||
self.assertRaises(AssertionError, queue_is_done, queue)
|
||||
|
||||
queue.pop()
|
||||
assert queue_is_ready(queue)
|
||||
assert queue_is_done(queue)
|
||||
|
||||
def test_pop_logic(self):
|
||||
sources = {}
|
||||
ids = ['a', 'b', 'c']
|
||||
for id in ids:
|
||||
sources[id] = deque()
|
||||
|
||||
assert not ready(sources)
|
||||
assert not done(sources)
|
||||
|
||||
# All sources must have a message to be ready/done
|
||||
sources['a'].append(to_dt("datetime"))
|
||||
assert not ready(sources)
|
||||
assert not done(sources)
|
||||
sources['a'].pop()
|
||||
|
||||
for id in ids:
|
||||
sources[id].append(to_dt("datetime"))
|
||||
|
||||
assert ready(sources)
|
||||
assert not done(sources)
|
||||
|
||||
for id in ids:
|
||||
sources[id].appendleft(to_dt("DONE"))
|
||||
|
||||
# ["DONE", message] will trip an assert in queue_is_done.
|
||||
assert ready(sources)
|
||||
self.assertRaises(AssertionError, done, sources)
|
||||
|
||||
for id in ids:
|
||||
sources[id].pop()
|
||||
|
||||
assert ready(sources)
|
||||
assert done(sources)
|
||||
|
||||
class DateSortTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def run_date_sort(self, events, expected, source_ids):
|
||||
"""
|
||||
Take a list of events, their source_ids, and an expected sorting.
|
||||
Assert that date_sort's output agrees with expected.
|
||||
"""
|
||||
sort_gen = date_sort(events, source_ids)
|
||||
l = list(sort_gen)
|
||||
assert l == expected
|
||||
|
||||
def test_single_source(self):
|
||||
source_ids = ['a']
|
||||
# 100 events, increasing by a minute at a time.
|
||||
type = zp.DATASOURCE_TYPE.TRADE
|
||||
dates = list(date_gen(count = 100))
|
||||
dates.append("DONE")
|
||||
|
||||
# [('a', date1, type), ('a', date2, type), ... ('a', "DONE", type)]
|
||||
event_args = zip(cycle(source_ids), iter(dates), cycle([type]))
|
||||
|
||||
# Turn event_args into proper events.
|
||||
events = [mock_data_unframe(*args) for args in event_args]
|
||||
|
||||
# We don't expected Feed to yield the last event.
|
||||
expected = events[:-1]
|
||||
|
||||
event_gen = (e for e in events)
|
||||
|
||||
self.run_date_sort(event_gen, expected, source_ids)
|
||||
|
||||
def test_multi_source(self):
|
||||
source_ids = ['a', 'b']
|
||||
type = zp.DATASOURCE_TYPE.TRADE
|
||||
|
||||
# Set up source 'a'. Outputs 20 events with 2 minute deltas.
|
||||
delta_a = timedelta(minutes = 2)
|
||||
dates_a = list(date_gen(delta = delta_a, count = 20))
|
||||
dates_a.append("DONE")
|
||||
|
||||
events_a_args = zip(cycle(['a']), iter(dates_a), cycle([type]))
|
||||
events_a = [mock_data_unframe(*args) for args in events_a_args]
|
||||
|
||||
# Set up source 'b'. Outputs 10 events with 1 minute deltas.
|
||||
delta_b = timedelta(minutes = 1)
|
||||
dates_b = list(date_gen(delta = delta_b, count = 10))
|
||||
dates_b.append("DONE")
|
||||
|
||||
events_b_args = zip(cycle(['b']), iter(dates_b), cycle([type]))
|
||||
events_b = [mock_data_unframe(*args) for args in events_b_args]
|
||||
|
||||
# The expected output is all non-DONE events in both a and b,
|
||||
# sorted first by dt and then by source_id.
|
||||
non_dones = events_a[:-1] + events_b[:-1]
|
||||
expected = sorted(non_dones, compare_by_dt_source_id)
|
||||
|
||||
# Alternating between a and b.
|
||||
interleaved = alternate(iter(events_a), iter(events_b))
|
||||
self.run_date_sort(interleaved, expected, source_ids)
|
||||
|
||||
# All of a, then all of b.
|
||||
|
||||
sequential = chain(iter(events_a), iter(events_b))
|
||||
self.run_date_sort(sequential, expected, source_ids)
|
||||
|
||||
def test_sorted_sources(self):
|
||||
|
||||
filter = [1,2]
|
||||
#Set up source a. One hour between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {'sids' : [1,2,3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(hours = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
#Set up source b. One day between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {'sids' : [1,2,3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(days = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
#Set up source c. One minute between events.
|
||||
args_c = tuple()
|
||||
kwargs_c = {'sids' : [1,2,3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
# Set up source d. This should produce no events because the
|
||||
# internal sids don't match the filter.
|
||||
args_d = tuple()
|
||||
kwargs_d = {'sids' : [3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
|
||||
sources = (SpecificEquityTrades,) * 4
|
||||
source_args = (args_a, args_b, args_c, args_d)
|
||||
source_kwargs = (kwargs_a, kwargs_b, kwargs_c, kwargs_d)
|
||||
|
||||
# Generate our expected source_ids.
|
||||
zip_args = zip(source_args, source_kwargs)
|
||||
expected_ids = ["SpecificEquityTrades" + hash_args(*args, **kwargs)
|
||||
for args, kwargs in zip_args]
|
||||
|
||||
# Pipe our sources into sort.
|
||||
sort_out = date_sorted_sources(sources, source_args, source_kwargs)
|
||||
|
||||
# Read all the values from sort and assert that they arrive in
|
||||
# the correct sorting with the expected hash values.
|
||||
to_list = list(sort_out)
|
||||
copy = to_list[:]
|
||||
for e in to_list:
|
||||
# All events should match one of our expected source_ids.
|
||||
assert e.source_id in expected_ids
|
||||
# But none of them should match source_d.
|
||||
assert e.source_id != hash_args(*args_d, **kwargs_d)
|
||||
|
||||
expected = sorted(copy, compare_by_dt_source_id)
|
||||
assert to_list == expected
|
||||
|
||||
def mock_data_unframe(source_id, dt, type):
|
||||
event = ndict()
|
||||
event.source_id = source_id
|
||||
event.dt = dt
|
||||
event.type = type
|
||||
return event
|
||||
|
||||
def to_dt(val):
|
||||
return ndict({'dt': val})
|
||||
|
||||
def compare_by_dt_source_id(x,y):
|
||||
if x.dt < y.dt:
|
||||
return -1
|
||||
elif x.dt > y.dt:
|
||||
return 1
|
||||
|
||||
elif x.source_id < y.source_id:
|
||||
return -1
|
||||
elif x.source_id > y.source_id:
|
||||
return 1
|
||||
|
||||
else:
|
||||
return 0
|
||||
@@ -0,0 +1,257 @@
|
||||
import pytz
|
||||
|
||||
from unittest2 import TestCase
|
||||
from itertools import cycle, chain, izip, izip_longest
|
||||
from datetime import datetime, timedelta
|
||||
from collections import deque
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.sort import \
|
||||
date_sort, \
|
||||
ready, \
|
||||
done, \
|
||||
queue_is_ready,\
|
||||
queue_is_done
|
||||
from zipline.gens.utils import hash_args, alternate, done_message
|
||||
from zipline.gens.tradegens import date_gen, SpecificEquityTrades
|
||||
from zipline.gens.composites import date_sorted_sources
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
class HelperTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def test_individual_queue_logic(self):
|
||||
queue = deque()
|
||||
# Empty queues are neither done nor ready.
|
||||
assert not queue_is_ready(queue)
|
||||
assert not queue_is_done(queue)
|
||||
|
||||
queue.append(to_dt('foo'))
|
||||
assert queue_is_ready(queue)
|
||||
assert not queue_is_done(queue)
|
||||
|
||||
|
||||
queue.appendleft(to_dt('DONE'))
|
||||
assert queue_is_ready(queue)
|
||||
|
||||
# Checking done when we have a message after done will trip an assert.
|
||||
self.assertRaises(AssertionError, queue_is_done, queue)
|
||||
|
||||
queue.pop()
|
||||
assert queue_is_ready(queue)
|
||||
assert queue_is_done(queue)
|
||||
|
||||
def test_pop_logic(self):
|
||||
sources = {}
|
||||
ids = ['a', 'b', 'c']
|
||||
for id in ids:
|
||||
sources[id] = deque()
|
||||
|
||||
assert not ready(sources)
|
||||
assert not done(sources)
|
||||
|
||||
# All sources must have a message to be ready/done
|
||||
sources['a'].append(to_dt("datetime"))
|
||||
assert not ready(sources)
|
||||
assert not done(sources)
|
||||
sources['a'].pop()
|
||||
|
||||
for id in ids:
|
||||
sources[id].append(to_dt("datetime"))
|
||||
|
||||
assert ready(sources)
|
||||
assert not done(sources)
|
||||
|
||||
for id in ids:
|
||||
sources[id].appendleft(to_dt("DONE"))
|
||||
|
||||
# ["DONE", message] will trip an assert in queue_is_done.
|
||||
assert ready(sources)
|
||||
self.assertRaises(AssertionError, done, sources)
|
||||
|
||||
for id in ids:
|
||||
sources[id].pop()
|
||||
|
||||
assert ready(sources)
|
||||
assert done(sources)
|
||||
|
||||
class DateSortTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def run_date_sort(self, event_stream, expected, source_ids):
|
||||
"""
|
||||
Take a list of events, their source_ids, and an expected sorting.
|
||||
Assert that date_sort's output agrees with expected.
|
||||
"""
|
||||
sort_out = date_sort(event_stream, source_ids)
|
||||
for m1, m2 in izip_longest(sort_out, expected):
|
||||
assert m1 == m2
|
||||
|
||||
def test_single_source(self):
|
||||
|
||||
# Just using the built-in defaults. See
|
||||
# zipline/gens/tradegens.py
|
||||
source = SpecificEquityTrades()
|
||||
expected = list(source)
|
||||
source.rewind()
|
||||
# The raw source doesn't handle done messaging, so we need to
|
||||
# append a done message for sort to work properly.
|
||||
with_done = chain(source, [done_message(source.get_hash())])
|
||||
self.run_date_sort(with_done, expected, [source.get_hash()])
|
||||
|
||||
def test_multi_source(self):
|
||||
|
||||
filter = [2,3]
|
||||
args_a = tuple()
|
||||
kwargs_a = {
|
||||
'count' : 100,
|
||||
'sids' : [1,2,3],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 6),
|
||||
'filter' : filter
|
||||
}
|
||||
source_a = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
|
||||
args_b = tuple()
|
||||
kwargs_b = {
|
||||
'count' : 100,
|
||||
'sids' : [2,3,4],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 5),
|
||||
'filter' : filter
|
||||
}
|
||||
source_b = SpecificEquityTrades(*args_b, **kwargs_b)
|
||||
|
||||
all_events = list(chain(source_a, source_b))
|
||||
|
||||
# The expected output is all events, sorted by dt with
|
||||
# source_id as a tiebreaker.
|
||||
expected = sorted(all_events, comp)
|
||||
source_ids = [source_a.get_hash(), source_b.get_hash()]
|
||||
|
||||
# Generating the events list consumes the sources. Rewind them
|
||||
# for testing.
|
||||
source_a.rewind()
|
||||
source_b.rewind()
|
||||
|
||||
# Append a done message to each source.
|
||||
with_done_a = chain(source_a, [done_message(source_a.get_hash())])
|
||||
with_done_b = chain(source_b, [done_message(source_b.get_hash())])
|
||||
|
||||
interleaved = alternate(with_done_a, with_done_b)
|
||||
|
||||
# Test sort with alternating messages from source_a and
|
||||
# source_b.
|
||||
self.run_date_sort(interleaved, expected, source_ids)
|
||||
|
||||
source_a.rewind()
|
||||
source_b.rewind()
|
||||
with_done_a = chain(source_a, [done_message(source_a.get_hash())])
|
||||
with_done_b = chain(source_b, [done_message(source_b.get_hash())])
|
||||
|
||||
sequential = chain(with_done_a, with_done_b)
|
||||
|
||||
# Test sort with all messages from a, followed by all messages
|
||||
# from b.
|
||||
|
||||
self.run_date_sort(sequential, expected, source_ids)
|
||||
|
||||
|
||||
def test_sort_composite(self):
|
||||
|
||||
filter = [1,2]
|
||||
|
||||
#Set up source a. One hour between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {
|
||||
'count' : 100,
|
||||
'sids' : [1],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(hours = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
source_a = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
|
||||
#Set up source b. One day between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {
|
||||
'count' : 50,
|
||||
'sids' : [2],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(days = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
source_b = SpecificEquityTrades(*args_b, **kwargs_b)
|
||||
|
||||
#Set up source c. One minute between events.
|
||||
args_c = tuple()
|
||||
kwargs_c = {
|
||||
'count' : 150,
|
||||
'sids' : [1,2],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
source_c = SpecificEquityTrades(*args_c, **kwargs_c)
|
||||
# Set up source d. This should produce no events because the
|
||||
# internal sids don't match the filter.
|
||||
args_d = tuple()
|
||||
kwargs_d = {
|
||||
'count' : 50,
|
||||
'sids' : [3],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
source_d = SpecificEquityTrades(*args_d, **kwargs_d)
|
||||
sources = [source_a, source_b, source_c, source_d]
|
||||
hashes = [source.get_hash() for source in sources]
|
||||
|
||||
sort_out = date_sorted_sources(*sources)
|
||||
|
||||
# Read all the values from sort and assert that they arrive in
|
||||
# the correct sorting with the expected hash values.
|
||||
to_list = list(sort_out)
|
||||
copy = to_list[:]
|
||||
|
||||
# We should have 300 events (100 from a, 150 from b, 50 from c)
|
||||
assert len(to_list) == 300
|
||||
|
||||
for e in to_list:
|
||||
# All events should match one of our expected source_ids.
|
||||
assert e.source_id in hashes
|
||||
# But none of them should match source_d.
|
||||
assert e.source_id != source_d.get_hash()
|
||||
|
||||
# The events should be sorted by dt, with source_id as tiebreaker.
|
||||
expected = sorted(copy, comp)
|
||||
|
||||
assert to_list == expected
|
||||
|
||||
def compare_by_dt_source_id(x,y):
|
||||
if x.dt < y.dt:
|
||||
return -1
|
||||
elif x.dt > y.dt:
|
||||
return 1
|
||||
|
||||
elif x.source_id < y.source_id:
|
||||
return -1
|
||||
elif x.source_id > y.source_id:
|
||||
return 1
|
||||
|
||||
else:
|
||||
return 0
|
||||
|
||||
#Alias for ease of use
|
||||
comp = compare_by_dt_source_id
|
||||
+36
-15
@@ -13,8 +13,9 @@ TransformBundle = namedtuple("TransformBundle", ['tnfm', 'args', 'kwargs'])
|
||||
|
||||
def date_sorted_sources(*sources):
|
||||
"""
|
||||
Takes an iterable of SortBundles, generating namestrings and initialized datasources
|
||||
for each before piping them into a date_sort.
|
||||
Takes an iterable of SortBundles, generating namestrings and
|
||||
initialized datasources for each before piping them into a
|
||||
date_sort.
|
||||
"""
|
||||
|
||||
for source in sources:
|
||||
@@ -28,21 +29,21 @@ def date_sorted_sources(*sources):
|
||||
# one element at a time from each.
|
||||
stream_in = roundrobin(sources, names)
|
||||
|
||||
# Guarantee the flat stream will be sorted by date, using source_id as
|
||||
# tie-breaker, which is fully deterministic (given deterministic string
|
||||
# representation for all args/kwargs)
|
||||
# Guarantee the flat stream will be sorted by date, using
|
||||
# source_id as tie-breaker, which is fully deterministic (given
|
||||
# deterministic string representation for all args/kwargs)
|
||||
|
||||
return date_sort(stream_in, names)
|
||||
|
||||
def merged_transforms(sorted_stream, *transforms):
|
||||
"""
|
||||
A generator that takes the expected output of a date_sort, pipes it
|
||||
through a given set of transforms, and runs the results throught a
|
||||
merge to output a unified stream. tnfms should be a list of
|
||||
pointers to generator functions. tnfm_args should be a list of
|
||||
tuples, representing the arguments to be passed to each transform.
|
||||
tnfm_kwargs should be a list of dictionaries representing keyword
|
||||
arguments to each transform.
|
||||
A generator that takes the expected output of a date_sort, pipes
|
||||
it through a given set of transforms, and runs the results
|
||||
through a merge to output a unified stream. tnfms should be a
|
||||
list of pointers to generator functions. tnfm_args should be a
|
||||
list of tuples, representing the arguments to be passed to each
|
||||
transform. tnfm_kwargs should be a list of dictionaries
|
||||
representing keyword arguments to each transform.
|
||||
"""
|
||||
for transform in transforms:
|
||||
assert isinstance(transform, StatefulTransform)
|
||||
@@ -62,15 +63,35 @@ def merged_transforms(sorted_stream, *transforms):
|
||||
# Roundrobin the outputs of our transforms to create a single flat
|
||||
# stream.
|
||||
to_merge = roundrobin(tnfm_gens, namestrings)
|
||||
|
||||
# Pipe the stream into merge.
|
||||
merged = merge(to_merge, namestrings)
|
||||
# Return the merged events.
|
||||
return merged
|
||||
|
||||
def zipline(sources, transforms, endpoint):
|
||||
assert isinstance(sources, (list, tuple))
|
||||
def sequential_transforms(stream_in, *transforms):
|
||||
"""
|
||||
Apply each transform in transforms sequentially to each event in stream_in.
|
||||
Each transform application will add a new entry indexed to the transform's
|
||||
hash string.
|
||||
"""
|
||||
|
||||
assert isinstance(transforms, (list, tuple))
|
||||
for tnfm in transforms:
|
||||
tnfm.forward_all = False
|
||||
tnfm.update_in_place = False
|
||||
tnfm.append_value = True
|
||||
|
||||
# Recursively apply all transforms to the stream.
|
||||
stream_out = reduce(lambda stream, tnfm: tnfm.transform(stream),
|
||||
transforms,
|
||||
stream_in)
|
||||
return stream_out
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
+55
-15
@@ -1,14 +1,16 @@
|
||||
import pytz
|
||||
import time
|
||||
|
||||
from time import sleep
|
||||
from pprint import pprint as pp
|
||||
from datetime import datetime, timedelta
|
||||
from itertools import izip
|
||||
|
||||
from zipline.utils.factory import create_trading_environment
|
||||
from zipline.test_algorithms import TestAlgorithm
|
||||
|
||||
from zipline.gens.composites import SourceBundle, TransformBundle, \
|
||||
date_sorted_sources, merged_transforms
|
||||
date_sorted_sources, merged_transforms, sequential_transforms
|
||||
from zipline.gens.tradegens import SpecificEquityTrades
|
||||
from zipline.gens.transform import MovingAverage, Passthrough, StatefulTransform
|
||||
from zipline.gens.tradesimulation import TradeSimulationClient as tsc
|
||||
@@ -18,43 +20,81 @@ import zipline.protocol as zp
|
||||
if __name__ == "__main__":
|
||||
|
||||
filter = [2,3]
|
||||
#Set up source a. One minute between events.
|
||||
#Set up source a. Six minutes between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {
|
||||
'count' : 325,
|
||||
'count' : 1000,
|
||||
'sids' : [1,2,3],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(hours = 6),
|
||||
'delta' : timedelta(minutes = 6),
|
||||
'filter' : filter
|
||||
}
|
||||
source_a = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
source_a_prime = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
|
||||
#Set up source b. Two minutes between events.
|
||||
#Set up source b. Five minutes between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {
|
||||
'count' : 7500,
|
||||
'count' : 1000,
|
||||
'sids' : [2,3,4],
|
||||
'start' : datetime(2012,1,3,14, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 5),
|
||||
'filter' : filter
|
||||
}
|
||||
source_b = SpecificEquityTrades(*args_b, **kwargs_b)
|
||||
|
||||
#Set up source c. Three minutes between events.
|
||||
source_b_prime = SpecificEquityTrades(*args_b, **kwargs_b)
|
||||
|
||||
sorted = date_sorted_sources(source_a, source_b)
|
||||
sorted_prime = date_sorted_sources(
|
||||
source_a_prime,
|
||||
source_b_prime
|
||||
)
|
||||
|
||||
passthrough = StatefulTransform(Passthrough)
|
||||
mavg_price = StatefulTransform(MovingAverage, timedelta(minutes = 20), ['price'])
|
||||
mavg_price = StatefulTransform(
|
||||
MovingAverage,
|
||||
timedelta(minutes = 20),
|
||||
['price']
|
||||
)
|
||||
|
||||
passthrough_prime = StatefulTransform(Passthrough)
|
||||
mavg_price_prime = StatefulTransform(
|
||||
MovingAverage,
|
||||
timedelta(minutes = 20),
|
||||
['price']
|
||||
)
|
||||
|
||||
merged = merged_transforms(sorted, passthrough, mavg_price)
|
||||
start = time.time()
|
||||
for message in merged:
|
||||
assert 1 + 1 == 2
|
||||
stop = time.time()
|
||||
merge_time = stop - start
|
||||
print "Merge time: %s" % str(merge_time)
|
||||
|
||||
sequential = sequential_transforms(
|
||||
sorted_prime,
|
||||
passthrough_prime,
|
||||
mavg_price_prime
|
||||
)
|
||||
|
||||
algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3])
|
||||
environment = create_trading_environment(year = 2012)
|
||||
style = zp.SIMULATION_STYLE.FIXED_SLIPPAGE
|
||||
start = time.time()
|
||||
for message in sequential:
|
||||
assert 1 + 1 == 2
|
||||
stop = time.time()
|
||||
seq_time = stop - start
|
||||
print "Sequential time: %s" % str(seq_time)
|
||||
print "Merge/Seq: %s" % (str(merge_time/seq_time))
|
||||
|
||||
trading_client = tsc(algo, environment, style)
|
||||
|
||||
# merged = merged_transforms(sorted, passthrough, mavg_price)
|
||||
|
||||
for message in trading_client.simulate(merged):
|
||||
pp(message)
|
||||
# algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3])
|
||||
# environment = create_trading_environment(year = 2012)
|
||||
# style = zp.SIMULATION_STYLE.FIXED_SLIPPAGE
|
||||
|
||||
# trading_client = tsc(algo, environment, style)
|
||||
|
||||
# for message in trading_client.simulate(merged):
|
||||
# pp(message)
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@ from itertools import repeat
|
||||
|
||||
def merge(stream_in, tnfm_ids):
|
||||
"""
|
||||
A generator that takes a generator and a list of source_ids. We
|
||||
maintain an internal queue for each id in source_ids. Once we
|
||||
A generator that takes a generator and a list of transform ids. We
|
||||
maintain an internal queue for each id in tnfm_ids. Once we
|
||||
have a message from every queue, we pop an event from each queue
|
||||
and merge them together into an event. We raise an error if we
|
||||
do not receive the same number of events from all sources.
|
||||
@@ -54,9 +54,8 @@ def merge(stream_in, tnfm_ids):
|
||||
yield done_message('Merge')
|
||||
|
||||
def merge_one(sources):
|
||||
dict_primer = zip(sources.keys(), repeat(None))
|
||||
event_fields = ndict()
|
||||
|
||||
event_fields = ndict()
|
||||
for key, queue in sources.iteritems():
|
||||
|
||||
# Add transform value to the transforms dict.
|
||||
|
||||
@@ -14,7 +14,6 @@ def date_sort(stream_in, source_ids):
|
||||
have messages pending from all sources, we pull the earliest
|
||||
message and yield it.
|
||||
"""
|
||||
|
||||
assert isinstance(source_ids, (list, tuple))
|
||||
|
||||
# Set up an internal queue for each expected source.
|
||||
@@ -41,7 +40,7 @@ def date_sort(stream_in, source_ids):
|
||||
message = pop_oldest(sources)
|
||||
assert_sort_protocol(message)
|
||||
yield message
|
||||
|
||||
|
||||
# We should have only a done message left in each queue.
|
||||
for queue in sources.itervalues():
|
||||
assert len(queue) == 1, "Bad queue in date_sort on exit: %s" % queue
|
||||
|
||||
@@ -56,9 +56,12 @@ class StatefulTransform(object):
|
||||
|
||||
self.forward_all = tnfm_class.__dict__.get('FORWARDER', False)
|
||||
self.update_in_place = tnfm_class.__dict__.get('UPDATER', False)
|
||||
self.append_value = tnfm_class.__dict__.get('APPENDER', False)
|
||||
|
||||
# You can't be both a forwarded and an updater.
|
||||
assert not all([self.forward_all, self.update_in_place])
|
||||
# You only one special behavior mode can be set.
|
||||
assert sum(map(int, [self.forward_all,
|
||||
self.update_in_place,
|
||||
self.append_value])) <= 1
|
||||
|
||||
# Create an instance of our transform class.
|
||||
self.state = tnfm_class(*args, **kwargs)
|
||||
@@ -75,11 +78,15 @@ class StatefulTransform(object):
|
||||
def _gen(self, stream_in):
|
||||
# IMPORTANT: Messages may contain pointers that are shared with
|
||||
# other streams, so we only manipulate copies.
|
||||
|
||||
for message in stream_in:
|
||||
|
||||
# allow upstream generators to yield None to avoid
|
||||
# blocking.
|
||||
if message == None:
|
||||
continue
|
||||
|
||||
#TODO: refactor this to avoid unnecessary copying.
|
||||
|
||||
assert_sort_unframe_protocol(message)
|
||||
message_copy = deepcopy(message)
|
||||
@@ -87,22 +94,43 @@ class StatefulTransform(object):
|
||||
# Same shared pointer issue here as above.
|
||||
tnfm_value = self.state.update(deepcopy(message_copy))
|
||||
|
||||
# If we want to keep all original values, plus append tnfm_id
|
||||
# and tnfm_value. Used for Passthrough.
|
||||
# FORWARDER flag means we want to keep all original
|
||||
# values, plus append tnfm_id and tnfm_value. Used for
|
||||
# preserving the original event fields when our output
|
||||
# will be fed into a merge.
|
||||
if self.forward_all:
|
||||
out_message = message_copy
|
||||
out_message.tnfm_id = self.namestring
|
||||
out_message.tnfm_value = tnfm_value
|
||||
yield out_message
|
||||
|
||||
# Our expectation is that the transform simply updated the
|
||||
# message it was passed. Useful for chaining together
|
||||
# multiple transforms, e.g. TransactionSimulator/PerformanceTracker.
|
||||
# UPDATER flag should be used for transforms that
|
||||
# side-effectfully modify the event they are passed.
|
||||
# Updated messages are passed along exactly as they are
|
||||
# returned to use by our state class. Useful for chaining
|
||||
# specific transforms that won't be fed to a merge. (See
|
||||
# the implementation of TradeSimulationClient for example
|
||||
# usage of this flag with PerformanceTracker and
|
||||
# TransactionSimulator.
|
||||
elif self.update_in_place:
|
||||
yield tnfm_value
|
||||
|
||||
# APPENDER flag should be used to add a single new
|
||||
# key-value pair to the event. The new key is this
|
||||
# transform's namestring, and it's value is the value
|
||||
# returned by state.update(event). This is almost
|
||||
# identical to the behavior of FORWARDER, except we
|
||||
# compress the two calculated values (tnfm_id, and
|
||||
# tnfm_value) into a single field.
|
||||
elif self.append_value:
|
||||
out_message = message_copy
|
||||
out_message[self.namestring] = tnfm_value
|
||||
yield out_message
|
||||
|
||||
# Otherwise send tnfm_id, tnfm_value, and the message
|
||||
# date. Useful for transforms being piped to a merge.
|
||||
# If no flags are set, we create a new message containing
|
||||
# just the tnfm_id, the event's datetime, and the
|
||||
# calculated tnfm_value. This is the default behavior for
|
||||
# a transform being fed into a merge.
|
||||
else:
|
||||
out_message = ndict()
|
||||
out_message.tnfm_id = self.namestring
|
||||
|
||||
@@ -66,6 +66,14 @@ def hash_args(*args, **kwargs):
|
||||
hasher.update(combined)
|
||||
return hasher.hexdigest()
|
||||
|
||||
def sum_true(bool_iterable):
|
||||
"""
|
||||
Takes an iterable of boolean values and returns the number of
|
||||
those values that are True.
|
||||
"""
|
||||
return sum(map(int, bool_iterable))
|
||||
|
||||
|
||||
def assert_datasource_protocol(event):
|
||||
"""Assert that an event meets the protocol for datasource outputs."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user