Files
catalyst/zipline/utils/test_utils.py
T
Jonathan Kamens 16e75a31d4 TST: Fix check() test utility function to detect all list differences
The check() function in zipline.utils.test_utils was only comparing
lists up to the length of the shortest list. This fix uses
izip_longest instead of izip so it compares up to the length of the
longest list, which among other things means that it will now
correctly report when one list is empty and the other is not.
2013-04-30 21:13:18 -04:00

151 lines
4.0 KiB
Python

from datetime import datetime
import blist
from zipline.utils.date_utils import EPOCH
from itertools import izip_longest
from logbook import FileHandler
from zipline.finance.blotter import ORDER_STATUS
def setup_logger(test, path='test.log'):
    """Route logbook output for *test* into the file at *path*.

    The new FileHandler is stored on ``test.log_handler`` (where
    teardown_logger later reads it) and then pushed as the application-wide
    handler.
    """
    test.log_handler = handler = FileHandler(path)
    handler.push_application()
def teardown_logger(test):
    """Undo setup_logger: pop ``test.log_handler`` off the application
    handler stack, then close it.

    The handler is closed even if popping it raises, so the file opened by
    setup_logger is not left open.
    """
    handler = test.log_handler
    try:
        handler.pop_application()
    finally:
        handler.close()
def check_list(test, a, b, label):
    """Recursively assert that sequences *a* and *b* are equal.

    Both must be ``list`` or ``blist.blist`` instances. Elements are compared
    pairwise with check(), labelling each one ``label[i]`` so a failure
    points at the offending index.
    """
    test.assertTrue(isinstance(a, (list, blist.blist)))
    test.assertTrue(isinstance(b, (list, blist.blist)))
    for i, (a_val, b_val) in enumerate(izip_longest(a, b)):
        check(test, a_val, b_val, label + "[" + str(i) + "]")
    # izip_longest pads the shorter list with None, so a longer list whose
    # extra trailing elements are all None would pass the loop above.
    # Comparing lengths closes that gap.
    test.assertEqual(len(a), len(b), "different lengths at: " + label)
def check_dict(test, a, b, label):
    """Recursively assert that dicts *a* and *b* are equal.

    Both must be ``dict`` instances with the same set of keys. Each value is
    then compared with check(), labelled ``label.key``.
    """
    test.assertTrue(isinstance(a, dict))
    test.assertTrue(isinstance(b, dict))
    test.assertEqual(sorted(a), sorted(b), "different keys at: " + label)
    for key in a:
        # Format the key instead of concatenating it, so non-string keys
        # (e.g. integer ids) build a path label instead of raising TypeError.
        check(test, a[key], b[key], "%s.%s" % (label, key))
def check_datetime(test, a, b, label):
    """Assert that *a* and *b* are datetimes with equal ``EPOCH`` values.

    Both operands must be ``datetime`` instances. They are then compared
    through ``EPOCH`` rather than directly.
    """
    for value in (a, b):
        test.assertTrue(isinstance(value, datetime))
    test.assertEqual(EPOCH(a), EPOCH(b), "mismatched dates " + label)
def check(test, a, b, label=None):
    """
    Check equality for arbitrarily nested dicts and lists that terminate
    in types that allow direct comparisons (string, ints, floats, datetimes)

    The type of *a* selects the comparison: dicts go to check_dict,
    lists/blists to check_list, datetimes to check_datetime, and anything
    else is compared with ``test.assertEqual``. An empty or missing *label*
    becomes ``'<root>'``.
    """
    path = label or '<root>'
    if isinstance(a, dict):
        checker = check_dict
    elif isinstance(a, (list, blist.blist)):
        checker = check_list
    elif isinstance(a, datetime):
        checker = check_datetime
    else:
        test.assertEqual(a, b, "mismatch on path: " + path)
        return
    checker(test, a, b, path)
def drain_zipline(test, zipline):
    """Run *zipline* to completion and collect its output.

    Iterates every update produced by *zipline*. Each update carrying a
    ``'daily_perf'`` entry adds its number of ``'transactions'`` to the
    running count.

    :param test: unused; kept for signature compatibility with callers.
    :param zipline: iterable of update mappings.
    :returns: ``(output, transaction_count)``, where ``output`` is the list
        of all updates in iteration order.
    """
    output = []
    transaction_count = 0
    # start the simulation
    for update in zipline:
        output.append(update)
        if 'daily_perf' in update:
            transaction_count += len(update['daily_perf']['transactions'])
    return output, transaction_count
def assert_single_position(test, zipline):
    """Drain *zipline* and assert that it finished holding one position.

    Using ``test``'s assertion methods and ``test.zipline_test_config``,
    this checks that:

    * the transaction count equals ``config['expected_transactions']`` when
      that key exists, and ``config['order_count']`` otherwise;
    * every order reported under ``daily_perf['orders']`` has status
      ``ORDER_STATUS.FILLED``, using the most recently reported version of
      each order id;
    * the positions in the second-to-last update hold exactly one entry,
      for ``config['sid']``.

    :returns: the ``(output, transaction_count)`` pair from drain_zipline.
    """
    output, transaction_count = drain_zipline(test, zipline)
    config = test.zipline_test_config

    # Look up 'order_count' only when 'expected_transactions' is absent,
    # so configs without 'order_count' still work.
    if 'expected_transactions' in config:
        expected_transactions = config['expected_transactions']
    else:
        expected_transactions = config['order_count']
    test.assertEqual(expected_transactions, transaction_count)

    # the final message is the risk report, the second to
    # last is the final day's results. Positions is a list of
    # dicts. Fail clearly, instead of with an IndexError, when fewer
    # than two messages were produced.
    test.assertTrue(
        len(output) >= 2,
        "Expected at least two messages, got " + str(len(output)))
    closing_positions = output[-2]['daily_perf']['positions']

    # confirm that all orders were filled.
    # iterate over the output updates, overwriting
    # orders when they are updated. Then check the status on all.
    orders_by_id = {}
    for update in output:
        if 'daily_perf' in update and 'orders' in update['daily_perf']:
            for order in update['daily_perf']['orders']:
                orders_by_id[order['id']] = order
    for order in orders_by_id.values():
        test.assertEqual(
            order['status'],
            ORDER_STATUS.FILLED,
            "Order %r has status %r; expected ORDER_STATUS.FILLED"
            % (order['id'], order['status']))

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )
    sid = config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )
    return output, transaction_count
class ExceptionSource(object):
    """Test double: an iterator whose every step raises ZeroDivisionError.

    ``get_hash`` returns the fixed identifier ``"ExceptionSource"``.
    """

    def get_hash(self):
        return "ExceptionSource"

    def __iter__(self):
        return self

    def next(self):
        # Deliberately divide by zero so iteration raises ZeroDivisionError.
        5 / 0

    # Python 3's iterator protocol calls __next__, not next. Alias it so the
    # source is a real iterator there too and fails with ZeroDivisionError
    # rather than a TypeError from iter()/next().
    __next__ = next
class ExceptionTransform(object):
    """Test double: a transform whose ``update`` always raises AssertionError.

    ``get_hash`` returns the fixed identifier ``"ExceptionTransform"``.
    """

    def get_hash(self):
        return "ExceptionTransform"

    def update(self, event):
        # Raise explicitly rather than using ``assert False``: assert
        # statements are removed under ``python -O``, which would make this
        # transform silently succeed.
        raise AssertionError("An assertion message")