mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 06:28:38 +08:00
added tests for filter, refactoring factory a bit to make sure N trades are created when N are requested.
This commit is contained in:
@@ -95,7 +95,8 @@ class TradeSimulationClient(qmsg.Component):
|
||||
|
||||
def run_algorithm(self):
|
||||
frame = self.get_frame()
|
||||
self.algorithm.handle_frame(frame)
|
||||
if len(frame) > 0:
|
||||
self.algorithm.handle_frame(frame)
|
||||
|
||||
def connect_order(self):
|
||||
return self.connect_push_socket(self.addresses['order_address'])
|
||||
@@ -129,6 +130,7 @@ class TradeSimulationClient(qmsg.Component):
|
||||
def get_frame(self):
|
||||
for event in self.event_queue:
|
||||
self.event_frame[event['sid']] = event
|
||||
self.event_queue = []
|
||||
return self.event_frame
|
||||
|
||||
class OrderDataSource(qmsg.DataSource):
|
||||
|
||||
+1
-1
@@ -261,7 +261,7 @@ class SimulatedTrading(object):
|
||||
"""
|
||||
assert isinstance(source, zmsg.DataSource)
|
||||
self.check_started()
|
||||
source.set_filter('SID', self.algorithm.get_sid_filter)
|
||||
source.set_filter('SID', self.algorithm.get_sid_filter())
|
||||
self.sim.register_components([source])
|
||||
self.sources[source.get_id] = source
|
||||
|
||||
|
||||
+9
-3
@@ -13,12 +13,19 @@ class TradeDataSource(zm.DataSource):
|
||||
|
||||
def send(self, event):
|
||||
"""
|
||||
Sends the event iff it matches the internal SID filter.
|
||||
:param dict event: is a trade event with data as per
|
||||
:py:func: `zipline.protocol.TRADE_FRAME`
|
||||
:rtype: None
|
||||
"""
|
||||
|
||||
event.source_id = self.get_id
|
||||
message = zp.DATASOURCE_FRAME(event)
|
||||
if event.sid in self.filter['SID']:
|
||||
|
||||
message = zp.DATASOURCE_FRAME(event)
|
||||
else:
|
||||
message = zp.DATASOURCE_FRAME(None)
|
||||
|
||||
self.data_socket.send(message)
|
||||
|
||||
|
||||
@@ -56,9 +63,8 @@ class RandomEquityTrades(TradeDataSource):
|
||||
"dt" : self.trade_start + (self.minute * self.incr),
|
||||
})
|
||||
|
||||
self.send(event)
|
||||
|
||||
self.incr += 1
|
||||
|
||||
|
||||
|
||||
class SpecificEquityTrades(TradeDataSource):
|
||||
|
||||
@@ -50,11 +50,13 @@ class TestAlgorithm():
|
||||
self.incr = 0
|
||||
self.done = False
|
||||
self.order = None
|
||||
self.frame_count = 0
|
||||
|
||||
def set_order(self, order_callable):
|
||||
self.order = order_callable
|
||||
|
||||
def handle_frame(self, frame):
|
||||
self.frame_count += 1
|
||||
for dt, s in frame.iteritems():
|
||||
data = {}
|
||||
data.update(s)
|
||||
|
||||
+31
-28
@@ -63,20 +63,27 @@ def create_trade(sid, price, amount, datetime):
|
||||
})
|
||||
return row
|
||||
|
||||
def get_next_trading_dt(current, interval, trading_calendar):
|
||||
next = current
|
||||
while True:
|
||||
next = next + interval
|
||||
if trading_calendar.is_trading_day(next):
|
||||
break
|
||||
else:
|
||||
next = next + timedelta(days=1)
|
||||
|
||||
return next
|
||||
|
||||
def create_trade_history(sid, prices, amounts, start_time, interval, trading_calendar):
|
||||
i = 0
|
||||
trades = []
|
||||
current = start_time.replace(tzinfo = pytz.utc)
|
||||
|
||||
for price, amount in zip(prices, amounts):
|
||||
|
||||
if(trading_calendar.is_trading_day(current)):
|
||||
trade = create_trade(sid, price, amount, current)
|
||||
trades.append(trade)
|
||||
|
||||
current = current + interval
|
||||
else:
|
||||
current = current + timedelta(days=1)
|
||||
|
||||
current = get_next_trading_dt(current, interval, trading_calendar)
|
||||
trade = create_trade(sid, price, amount, current)
|
||||
trades.append(trade)
|
||||
|
||||
return trades
|
||||
|
||||
@@ -94,14 +101,10 @@ def create_txn_history(sid, priceList, amtList, startTime, interval, trading_cal
|
||||
current = startTime
|
||||
|
||||
for price, amount in zip(priceList, amtList):
|
||||
current = get_next_trading_dt(current, interval, trading_calendar)
|
||||
|
||||
if trading_calendar.is_trading_day(current):
|
||||
txns.append(create_txn(sid, price, amount, current))
|
||||
current = current + interval
|
||||
|
||||
else:
|
||||
current = current + timedelta(days=1)
|
||||
|
||||
txns.append(create_txn(sid, price, amount, current))
|
||||
current = current + interval
|
||||
return txns
|
||||
|
||||
|
||||
@@ -111,11 +114,16 @@ def create_returns(daycount, start, trading_calendar):
|
||||
current = start.replace(tzinfo=pytz.utc)
|
||||
one_day = timedelta(days = 1)
|
||||
while i < daycount:
|
||||
current = get_next_trading_dt(
|
||||
current,
|
||||
one_day,
|
||||
trading_calendar
|
||||
)
|
||||
i += 1
|
||||
r = risk.DailyReturn(current, random.random())
|
||||
test_range.append(r)
|
||||
current = current + one_day
|
||||
return [ x for x in test_range if(trading_calendar.is_trading_day(x.date)) ]
|
||||
|
||||
return test_range
|
||||
|
||||
|
||||
def create_returns_from_range(start, end, trading_calendar):
|
||||
@@ -123,14 +131,11 @@ def create_returns_from_range(start, end, trading_calendar):
|
||||
end = end.replace(tzinfo=pytz.utc)
|
||||
one_day = timedelta(days = 1)
|
||||
test_range = []
|
||||
i = 0
|
||||
while current <= end:
|
||||
current = current + one_day
|
||||
if(not trading_calendar.is_trading_day(current)):
|
||||
continue
|
||||
current = get_next_trading_dt(current, one_day, trading_calender)
|
||||
r = risk.DailyReturn(current, random.random())
|
||||
i += 1
|
||||
test_range.append(r)
|
||||
|
||||
|
||||
return test_range
|
||||
|
||||
@@ -138,13 +143,11 @@ def create_returns_from_list(returns, start, trading_calendar):
|
||||
current = start.replace(tzinfo=pytz.utc)
|
||||
one_day = timedelta(days = 1)
|
||||
test_range = []
|
||||
i = 0
|
||||
while len(test_range) < len(returns):
|
||||
if(trading_calendar.is_trading_day(current)):
|
||||
r = risk.DailyReturn(current, returns[i])
|
||||
i += 1
|
||||
test_range.append(r)
|
||||
current = current + one_day
|
||||
current = get_next_trading_dt(current, one_day, trading_calendar)
|
||||
r = risk.DailyReturn(current, returns[i])
|
||||
test_range.append(r)
|
||||
|
||||
return sorted(test_range, key=lambda(x):x.date)
|
||||
|
||||
def create_daily_trade_source(sids, trade_count, trading_environment):
|
||||
|
||||
@@ -107,3 +107,43 @@ class FinanceTestCase(TestCase):
|
||||
SID,
|
||||
"Portfolio should have one position in " + str(SID)
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
zipline.algorithm.frame_count,
|
||||
self.zipline_test_config['trade_count'],
|
||||
"The algorithm should receive all trades."
|
||||
)
|
||||
|
||||
@timed(DEFAULT_TIMEOUT)
|
||||
def test_sid_filter(self):
|
||||
"""Ensure the algorithm's filter prevents events from arriving."""
|
||||
# create a test algorithm whose filter will not match any of the
|
||||
# trade events sourced inside the zipline.
|
||||
order_amount = 100
|
||||
order_count = 100
|
||||
no_match_sid = 222
|
||||
test_algo = TestAlgorithm(
|
||||
no_match_sid,
|
||||
order_amount,
|
||||
order_count
|
||||
)
|
||||
|
||||
self.zipline_test_config['trade_count'] = 200
|
||||
self.zipline_test_config['algorithm'] = test_algo
|
||||
|
||||
zipline = SimulatedTrading.create_test_zipline(**self.zipline_test_config)
|
||||
|
||||
zipline.simulate(blocking=True)
|
||||
|
||||
#check that the algorithm received no events
|
||||
self.assertEqual(
|
||||
0,
|
||||
test_algo.frame_count,
|
||||
"The algorithm should not receive any events due to filtering."
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user