adding new finance test cases and a factories module to support them

This commit is contained in:
fawce
2012-02-26 20:06:17 -05:00
parent 6c6b8d5dae
commit e8567d9305
3 changed files with 230 additions and 4 deletions
+38 -4
View File
@@ -9,7 +9,7 @@ import zipline.messaging as qmsg
class RandomEquityTrades(qmsg.DataSource):
"""Generates a random stream of trades for testing."""
def __init__(self, sid, source_id, count):
qmsg.DataSource.__init__(self, source_id)
self.count = count
@@ -18,14 +18,17 @@ class RandomEquityTrades(qmsg.DataSource):
self.trade_start = datetime.datetime.now()
self.minute = datetime.timedelta(minutes=1)
self.price = random.uniform(5.0, 50.0)
def get_type(self):
return 'equity_trade'
def do_work(self):
if(self.incr == self.count):
self.signal_done()
return
self.price = self.price + random.uniform(-0.05, 0.05)
event = {
'sid' : self.sid,
@@ -33,6 +36,37 @@ class RandomEquityTrades(qmsg.DataSource):
'price' : self.price,
'volume' : random.randrange(100,10000,100)
}
self.send(event)
self.incr += 1
class SpecificEquityTrades(qmsg.DataSource):
"""Generates a random stream of trades for testing."""
def __init__(self, source_id, event_list):
"""
:event_list: should be a chronologically ordered list of dictionaries in the following form:
event = {
'sid' : self.sid,
'dt' : qutil.format_date(self.trade_start + (self.minute * self.incr)),
'price' : self.price,
'volume' : random.randrange(100,10000,100)
}
"""
qmsg.DataSource.__init__(self, source_id)
self.event_list = event_list
def get_type(self):
return 'equity_trade'
def do_work(self):
if(len(self.event_list) == 0):
self.signal_done()
return
event = self.event_list.pop(0)
self.send(event)
+107
View File
@@ -0,0 +1,107 @@
import datetime
import pytz
from algorithm.quantoenv import *
from algorithm.quantomodels import *
from algorithm.hostedalgorithm import *
from algorithm.risk import *
def createReturns(daycount, start):
i = 0
test_range = []
current = start.replace(tzinfo=pytz.utc)
one_day = datetime.timedelta(days = 1)
while i < daycount:
i += 1
r = daily_return(current, random.random())
test_range.append(r)
current = current + one_day
return [ x for x in test_range if(trading_calendar.is_trading_day(x.date)) ]
def createReturnsFromRange(start, end):
current = start.replace(tzinfo=pytz.utc)
end = end.replace(tzinfo=pytz.utc)
one_day = datetime.timedelta(days = 1)
test_range = []
i = 0
while current <= end:
current = current + one_day
if(not trading_calendar.is_trading_day(current)):
continue
r = daily_return(current, random.random())
i += 1
test_range.append(r)
return test_range
def createReturnsFromList(returns, start):
current = start.replace(tzinfo=pytz.utc)
one_day = datetime.timedelta(days = 1)
test_range = []
i = 0
while len(test_range) < len(returns):
if(trading_calendar.is_trading_day(current)):
r = daily_return(current, returns[i])
i += 1
test_range.append(r)
current = current + one_day
return test_range
def createAlgo(filename):
algo = Algorithm()
algo.code = getCodeFromFile(filename)
algo.title = filename
algo._id = pymongo.objectid.ObjectId()
hostedAlgo = HostedAlgorithm(algo)
return hostedAlgo
def getCodeFromFile(filename):
rVal = None
with open('./test/algo_samples/' + filename, 'r') as f:
rVal = f.read()
return rVal
def createTrade(sid, price, amount, datetime):
row = {}
row['sid'] = sid
row['dt'] = datetime
row['price'] = price
row['volume'] = amount
row['exchange_code'] = "fake exchange"
db = getTickDB()
db.equity.trades.minute.insert(row,safe=True)
dw = DocWrap()
dw.store = row
return dw
def createTradeHistory(sid, priceList, amtList, startTime, interval):
i = 0
trades = []
current = startTime
while i < len(priceList):
if(trading_calendar.is_trading_day(current)):
trades.append(createTrade(sid, priceList[i], amtList[i], current))
current = current + interval
i += 1
else:
current = current + datetime.timedelta(days=1)
return trades
def createTxn(sid, price, amount, datetime, btrid=None):
txn = Transaction(sid=sid, amount=amount, dt = datetime,
price=price, transaction_cost=-1*price*amount)
return txn
def createTxnHistory(sid, priceList, amtList, startTime, interval):
i = 0
txns = []
current = startTime
while i < len(priceList):
if(trading_calendar.is_trading_day(current)):
txns.append(createTxn(sid,priceList[i],amtList[i], current))
current = current + interval
i += 1
else:
current = current + datetime.timedelta(days=1)
return txns
+85
View File
@@ -0,0 +1,85 @@
"""Tests for the zipline.finance package"""
from unittest2 import TestCase
from zipline.test.test_messaging import SimulatorTestCase
from zipline.monitor import Controller
from zipline.messaging import DataSource
import zipline.util as qutil
class ThreadPoolExecutor(SimulatorTestCase, TestCase):
allocator = DummyAllocator(100)
def setup_logging(self):
qutil.configure_logging()
# lazy import by design
self.logger = mock.Mock()
def setup_allocator(self):
pass
def get_simulator(self, addresses):
return ThreadSimulator(addresses)
def get_controller(self):
# Allocate two more sockets
controller_sockets = self.allocate_sockets(2)
return Controller(
controller_sockets[0],
controller_sockets[1],
logging = self.logger,
)
#
def test_orders(self):
# Base Simuation
# --------------
# Allocate sockets for the simulator components
sockets = self.allocate_sockets(6)
addresses = {
'sync_address' : sockets[0],
'data_address' : sockets[1],
'feed_address' : sockets[2],
'merge_address' : sockets[3],
'result_address' : sockets[4],
'order_address' : sockets[5]
}
sim = self.get_simulator(addresses)
con = self.get_controller()
# Simulation Components
# ---------------------
ret1 = RandomEquityTrades(133, "ret1", 5000)
ret2 = RandomEquityTrades(134, "ret2", 5000)
mavg1 = MovingAverage("mavg1", 30)
mavg2 = MovingAverage("mavg2", 60)
client = TestClient(self, expected_msg_count=10000)
sim.register_components([ret1, ret2, mavg1, mavg2, client])
sim.register_controller( con )
# Simulation
# ----------
sim.simulate()
# Stop Running
# ------------
# TODO: less abrupt later, just shove a StopIteration
# down the pipe to make it stop spinning
sim.cuc._Thread__stop()
self.assertEqual(sim.feed.pending_messages(), 0,
"The feed should be drained of all messages, found {n} remaining."
.format(n=sim.feed.pending_messages())
)
class PredefinedDataSource()