mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 06:28:42 +08:00
0c96915404
Instead, log to test.log in the working directory when running tests. Also removes the config file for the logging module, which is no longer used now that we use Logbook.
140 lines
3.5 KiB
Python
140 lines
3.5 KiB
Python
from datetime import datetime
|
|
import blist
|
|
from zipline.utils.date_utils import EPOCH
|
|
from itertools import izip
|
|
from logbook import FileHandler
|
|
|
|
|
|
def setup_logger(test, path='test.log'):
    """Route log output for *test* into a logbook FileHandler.

    The handler writes to *path* (``test.log`` by default, i.e. relative
    to the current working directory).  It is stored on
    ``test.log_handler`` so ``teardown_logger`` can find it again, and it
    is activated with ``push_application()``.
    """
    handler = FileHandler(path)
    test.log_handler = handler
    handler.push_application()
|
|
|
|
|
|
def teardown_logger(test):
    """Undo ``setup_logger`` for *test*.

    The handler stored on ``test.log_handler`` is first removed with
    ``pop_application()`` and then closed.
    """
    handler = test.log_handler
    handler.pop_application()
    handler.close()
|
|
|
|
|
|
def check_list(test, a, b, label):
    """Assert that sequences *a* and *b* are element-wise equal.

    Both arguments must be ``list`` or ``blist.blist`` instances.  The
    lengths are compared first, because ``izip`` stops at the shorter
    input and would otherwise let a truncated or extended list pass
    unnoticed.  Each pair of elements is then compared recursively with
    ``check``.  The element index is appended to *label* (e.g.
    ``<root>.positions[2]``) so that a failure points at the offending
    element.
    """
    test.assertTrue(isinstance(a, (list, blist.blist)))
    test.assertTrue(isinstance(b, (list, blist.blist)))
    test.assertEqual(
        len(a),
        len(b),
        "mismatched list lengths at: " + label
    )
    # enumerate supplies the real index; the old manual counter was never
    # incremented, so every element was reported as [0].
    for i, (a_val, b_val) in enumerate(izip(a, b)):
        check(test, a_val, b_val, label + "[" + str(i) + "]")
|
|
|
|
|
|
def check_dict(test, a, b, label):
    """Assert that every key of dict *a* exists in *b* with an equal value.

    Values are compared recursively with ``check``.  Each key is appended
    to *label* as ``.key`` so that failures report the full path.  The
    key is formatted with ``%s``, so non-string keys (e.g. ints) no longer
    raise ``TypeError`` while the message is being built.  The
    ``'progress'`` key is skipped.

    NOTE(review): only the keys of *a* are checked.  Keys present in *b*
    but missing from *a* are not reported.  Confirm whether callers rely
    on this subset behaviour before making the comparison symmetric.
    """
    test.assertTrue(isinstance(a, dict))
    test.assertTrue(isinstance(b, dict))

    # ignore the extra fields used by dictshield
    ignored_keys = ('progress',)

    for key in a:
        if key in ignored_keys:
            continue
        path = "%s.%s" % (label, key)
        # key comes from a, so only its presence in b can actually fail.
        test.assertTrue(key in b, "missing key at: " + path)
        check(test, a[key], b[key], path)
|
|
|
|
|
|
def check_datetime(test, a, b, label):
    """Assert that datetimes *a* and *b* agree once converted with ``EPOCH``.

    Both values must be ``datetime`` instances.  The comparison is made on
    ``EPOCH(a)`` and ``EPOCH(b)`` rather than on the objects themselves.
    """
    for value in (a, b):
        test.assertTrue(isinstance(value, datetime))
    a_epoch = EPOCH(a)
    b_epoch = EPOCH(b)
    test.assertEqual(a_epoch, b_epoch, "mismatched dates " + label)
|
|
|
|
|
|
def check(test, a, b, label=None):
    """
    Recursively assert that *a* and *b* are equal.

    The type of *a* selects the comparison:

    * ``dict`` goes to ``check_dict``;
    * ``list`` or ``blist.blist`` goes to ``check_list``;
    * ``datetime`` goes to ``check_datetime``;
    * anything else (strings, ints, floats, ...) is compared directly with
      ``assertEqual``.

    *label* is the path used in failure messages.  Any falsy value
    (``None``, ``''``) becomes ``'<root>'``.
    """
    label = label or '<root>'

    if isinstance(a, dict):
        checker = check_dict
    elif isinstance(a, (list, blist.blist)):
        checker = check_list
    elif isinstance(a, datetime):
        checker = check_datetime
    else:
        test.assertEqual(a, b, "mismatch on path: " + label)
        return
    checker(test, a, b, label)
|
|
|
|
|
|
def drain_zipline(test, zipline):
    """Consume every message yielded by *zipline* and count transactions.

    *test* is not used here; it is kept so existing callers keep working.
    *zipline* is fully consumed.  Its messages are presumably dict-like,
    since they are probed with ``'daily_perf' in update``.

    Returns ``(output, transaction_count)``:

    * ``output`` is the list of all messages in iteration order;
    * ``transaction_count`` is the summed length of
      ``update['daily_perf']['transactions']`` over every message that has
      a ``'daily_perf'`` key.
    """
    # start the simulation: iterating the zipline drives it.
    output = list(zipline)
    transaction_count = sum(
        len(update['daily_perf']['transactions'])
        for update in output
        if 'daily_perf' in update
    )
    return output, transaction_count
|
|
|
|
|
|
def assert_single_position(test, zipline):
    """Drain *zipline* and assert the run ends holding exactly one position.

    Reads ``test.zipline_test_config``:

    * ``expected_transactions`` (optional), or ``order_count`` when it is
      absent: the number of transactions the run must produce;
    * ``sid``: the security that the single closing position must be in.

    Returns the ``(output, transaction_count)`` pair from
    ``drain_zipline`` so that callers can make further assertions.
    """
    output, transaction_count = drain_zipline(test, zipline)

    config = test.zipline_test_config
    if 'expected_transactions' in config:
        expected_transactions = config['expected_transactions']
    else:
        expected_transactions = config['order_count']
    test.assertEqual(expected_transactions, transaction_count)

    # the final message is the risk report, the second to
    # last is the final day's results. Positions is a list of
    # dicts.
    # Guard the output[-2] lookup so a short run is reported as a test
    # failure rather than a bare IndexError.
    test.assertTrue(
        len(output) >= 2,
        "Expected at least two messages, got " + str(len(output))
    )
    closing_positions = output[-2]['daily_perf']['positions']

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )

    sid = config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )

    return output, transaction_count
|
|
|
|
|
|
class ExceptionSource(object):
    """A source whose iteration always fails.

    The object is its own iterator, and its ``next`` method evaluates
    ``5 / 0``, so calling it raises ``ZeroDivisionError``.  Presumably
    used to test how errors raised by a data source propagate.
    """

    def __init__(self):
        """Nothing to initialise; the source holds no state."""

    def get_hash(self):
        """Return the fixed identifier of this source."""
        return 'ExceptionSource'

    def __iter__(self):
        """The source iterates over itself."""
        return self

    def next(self):
        """Raise ``ZeroDivisionError`` instead of producing an item."""
        return 5 / 0
|
|
|
|
|
|
class ExceptionTransform(object):
    """A transform whose ``update`` always fails with ``AssertionError``.

    Presumably used to test how errors raised inside a transform propagate.
    """

    def __init__(self):
        pass

    def get_hash(self):
        """Return the fixed identifier of this transform."""
        return "ExceptionTransform"

    def update(self, event):
        """Raise ``AssertionError("An assertion message")`` for any *event*.

        An explicit ``raise`` replaces ``assert False, ...`` because assert
        statements are stripped when Python runs with ``-O``, which would
        make this method silently return ``None``.
        """
        raise AssertionError("An assertion message")
|