added tests for filter, refactoring factory a bit to make sure N trades are created when N are requested.

This commit is contained in:
fawce
2012-03-20 16:38:32 -04:00
parent 26da4316d1
commit f20db3d01d
6 changed files with 86 additions and 33 deletions
+3 -1
View File
@@ -95,7 +95,8 @@ class TradeSimulationClient(qmsg.Component):
def run_algorithm(self):
frame = self.get_frame()
self.algorithm.handle_frame(frame)
if len(frame) > 0:
self.algorithm.handle_frame(frame)
def connect_order(self):
return self.connect_push_socket(self.addresses['order_address'])
@@ -129,6 +130,7 @@ class TradeSimulationClient(qmsg.Component):
def get_frame(self):
for event in self.event_queue:
self.event_frame[event['sid']] = event
self.event_queue = []
return self.event_frame
class OrderDataSource(qmsg.DataSource):
+1 -1
View File
@@ -261,7 +261,7 @@ class SimulatedTrading(object):
"""
assert isinstance(source, zmsg.DataSource)
self.check_started()
source.set_filter('SID', self.algorithm.get_sid_filter)
source.set_filter('SID', self.algorithm.get_sid_filter())
self.sim.register_components([source])
self.sources[source.get_id] = source
+9 -3
View File
@@ -13,12 +13,19 @@ class TradeDataSource(zm.DataSource):
def send(self, event):
"""
Sends the event iff it matches the internal SID filter.
:param dict event: is a trade event with data as per
:py:func: `zipline.protocol.TRADE_FRAME`
:rtype: None
"""
event.source_id = self.get_id
message = zp.DATASOURCE_FRAME(event)
if event.sid in self.filter['SID']:
message = zp.DATASOURCE_FRAME(event)
else:
message = zp.DATASOURCE_FRAME(None)
self.data_socket.send(message)
@@ -56,9 +63,8 @@ class RandomEquityTrades(TradeDataSource):
"dt" : self.trade_start + (self.minute * self.incr),
})
self.send(event)
self.incr += 1
class SpecificEquityTrades(TradeDataSource):
+2
View File
@@ -50,11 +50,13 @@ class TestAlgorithm():
self.incr = 0
self.done = False
self.order = None
self.frame_count = 0
def set_order(self, order_callable):
self.order = order_callable
def handle_frame(self, frame):
self.frame_count += 1
for dt, s in frame.iteritems():
data = {}
data.update(s)
+31 -28
View File
@@ -63,20 +63,27 @@ def create_trade(sid, price, amount, datetime):
})
return row
def get_next_trading_dt(current, interval, trading_calendar):
next = current
while True:
next = next + interval
if trading_calendar.is_trading_day(next):
break
else:
next = next + timedelta(days=1)
return next
def create_trade_history(sid, prices, amounts, start_time, interval, trading_calendar):
i = 0
trades = []
current = start_time.replace(tzinfo = pytz.utc)
for price, amount in zip(prices, amounts):
if(trading_calendar.is_trading_day(current)):
trade = create_trade(sid, price, amount, current)
trades.append(trade)
current = current + interval
else:
current = current + timedelta(days=1)
current = get_next_trading_dt(current, interval, trading_calendar)
trade = create_trade(sid, price, amount, current)
trades.append(trade)
return trades
@@ -94,14 +101,10 @@ def create_txn_history(sid, priceList, amtList, startTime, interval, trading_cal
current = startTime
for price, amount in zip(priceList, amtList):
current = get_next_trading_dt(current, interval, trading_calendar)
if trading_calendar.is_trading_day(current):
txns.append(create_txn(sid, price, amount, current))
current = current + interval
else:
current = current + timedelta(days=1)
txns.append(create_txn(sid, price, amount, current))
current = current + interval
return txns
@@ -111,11 +114,16 @@ def create_returns(daycount, start, trading_calendar):
current = start.replace(tzinfo=pytz.utc)
one_day = timedelta(days = 1)
while i < daycount:
current = get_next_trading_dt(
current,
one_day,
trading_calendar
)
i += 1
r = risk.DailyReturn(current, random.random())
test_range.append(r)
current = current + one_day
return [ x for x in test_range if(trading_calendar.is_trading_day(x.date)) ]
return test_range
def create_returns_from_range(start, end, trading_calendar):
@@ -123,14 +131,11 @@ def create_returns_from_range(start, end, trading_calendar):
end = end.replace(tzinfo=pytz.utc)
one_day = timedelta(days = 1)
test_range = []
i = 0
while current <= end:
current = current + one_day
if(not trading_calendar.is_trading_day(current)):
continue
current = get_next_trading_dt(current, one_day, trading_calender)
r = risk.DailyReturn(current, random.random())
i += 1
test_range.append(r)
return test_range
@@ -138,13 +143,11 @@ def create_returns_from_list(returns, start, trading_calendar):
current = start.replace(tzinfo=pytz.utc)
one_day = timedelta(days = 1)
test_range = []
i = 0
while len(test_range) < len(returns):
if(trading_calendar.is_trading_day(current)):
r = risk.DailyReturn(current, returns[i])
i += 1
test_range.append(r)
current = current + one_day
current = get_next_trading_dt(current, one_day, trading_calendar)
r = risk.DailyReturn(current, returns[i])
test_range.append(r)
return sorted(test_range, key=lambda(x):x.date)
def create_daily_trade_source(sids, trade_count, trading_environment):
+40
View File
@@ -107,3 +107,43 @@ class FinanceTestCase(TestCase):
SID,
"Portfolio should have one position in " + str(SID)
)
self.assertEqual(
zipline.algorithm.frame_count,
self.zipline_test_config['trade_count'],
"The algorithm should receive all trades."
)
@timed(DEFAULT_TIMEOUT)
def test_sid_filter(self):
"""Ensure the algorithm's filter prevents events from arriving."""
# create a test algorithm whose filter will not match any of the
# trade events sourced inside the zipline.
order_amount = 100
order_count = 100
no_match_sid = 222
test_algo = TestAlgorithm(
no_match_sid,
order_amount,
order_count
)
self.zipline_test_config['trade_count'] = 200
self.zipline_test_config['algorithm'] = test_algo
zipline = SimulatedTrading.create_test_zipline(**self.zipline_test_config)
zipline.simulate(blocking=True)
#check that the algorithm received no events
self.assertEqual(
0,
test_algo.frame_count,
"The algorithm should not receive any events due to filtering."
)