mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-27 19:30:28 +08:00
132 lines
3.3 KiB
Python
132 lines
3.3 KiB
Python
from datetime import datetime
|
|
import blist
|
|
from zipline.utils.date_utils import EPOCH
|
|
from itertools import izip
|
|
from logbook import FileHandler
|
|
|
|
|
|
def setup_logger(test, path='/var/log/zipline/zipline.log'):
    """
    Attach a logbook ``FileHandler`` writing to ``path`` for a test.

    The handler is stored on ``test.log_handler`` (so ``teardown_logger``
    can find it later) and then pushed with logbook's ``push_application``.
    """
    handler = test.log_handler = FileHandler(path)
    handler.push_application()
|
|
|
|
|
|
def teardown_logger(test):
    """
    Undo ``setup_logger``: pop ``test.log_handler`` and close it.

    ``close()`` runs in a ``finally`` clause so the handler's file is released
    even if ``pop_application()`` raises; the original exception still
    propagates to the caller.
    """
    handler = test.log_handler
    try:
        handler.pop_application()
    finally:
        # Always release the underlying file, even on a failed pop.
        handler.close()
|
|
|
|
|
|
def check_list(test, a, b, label):
    """
    Assert that sequences ``a`` and ``b`` are element-wise equal.

    Both arguments must be a ``list`` or ``blist.blist``. Their lengths must
    match. ``zip`` stops at the shorter input, so without the explicit length
    check a missing or extra trailing element would go unnoticed. Each pair
    of elements is then compared recursively with ``check``. Element ``i``
    is reported under the label ``label + "[i]"``.
    """
    test.assertTrue(isinstance(a, (list, blist.blist)))
    test.assertTrue(isinstance(b, (list, blist.blist)))

    test.assertEqual(
        len(a),
        len(b),
        "mismatched lengths at: %s (%d != %d)" % (label, len(a), len(b))
    )

    # enumerate supplies the real index for each element's failure label.
    for i, (a_val, b_val) in enumerate(zip(a, b)):
        check(test, a_val, b_val, "%s[%d]" % (label, i))
|
|
|
|
|
|
def check_dict(test, a, b, label):
    """
    Assert that dicts ``a`` and ``b`` have the same keys and equal values.

    Keys listed in ``ignored`` are skipped on both sides. Every other key
    must appear in both dicts. Values under keys of ``a`` are compared
    recursively with ``check`` under the label ``label + "." + key``.
    Keys that exist only in ``b`` are also reported as missing.
    """
    test.assertTrue(isinstance(a, dict))
    test.assertTrue(isinstance(b, dict))

    # ignore the extra fields used by dictshield
    ignored = ('progress',)

    for key in a:
        if key in ignored:
            continue
        # %-formatting keeps non-string keys from crashing the message build.
        path = "%s.%s" % (label, key)
        test.assertTrue(key in b, "missing key at: " + path)
        check(test, a[key], b[key], path)

    # The loop above only visits keys of ``a``; also report keys only in ``b``.
    for key in b:
        if key in ignored:
            continue
        test.assertTrue(key in a, "missing key at: %s.%s" % (label, key))
|
|
|
|
|
|
def check_datetime(test, a, b, label):
    """
    Assert that two datetimes agree once converted with ``EPOCH``.

    Both ``a`` and ``b`` must be ``datetime`` instances. They are compared
    through ``EPOCH(...)`` (presumably a conversion to epoch time; see
    ``zipline.utils.date_utils``) rather than directly. A mismatch is
    reported with ``label``.
    """
    for value in (a, b):
        test.assertTrue(isinstance(value, datetime))
    expected, actual = EPOCH(a), EPOCH(b)
    test.assertEqual(expected, actual, "mismatched dates " + label)
|
|
|
|
|
|
def check(test, a, b, label=None):
    """
    Recursively assert that ``a`` equals ``b``.

    Dispatches on the type of ``a``:

    * dicts go to ``check_dict``;
    * lists and blists go to ``check_list``;
    * datetimes go to ``check_datetime``;
    * anything else (strings, ints, floats, ...) is compared with
      ``assertEqual``.

    ``label`` is the path shown in failure messages. A falsy label means the
    top level and is reported as ``'<root>'``.
    """
    path = label or '<root>'

    if isinstance(a, dict):
        check_dict(test, a, b, path)
        return
    if isinstance(a, (list, blist.blist)):
        check_list(test, a, b, path)
        return
    if isinstance(a, datetime):
        check_datetime(test, a, b, path)
        return

    test.assertEqual(a, b, "mismatch on path: " + path)
|
|
|
|
|
|
def drain_zipline(test, zipline):
    """
    Consume every message produced by ``zipline``.

    :param test: the calling test case. It is not used here; it is kept so
        all helpers in this module share the ``(test, ...)`` signature.
    :param zipline: an iterable of messages; it is fully exhausted.
    :returns: ``(output, transaction_count)``.
        ``output`` is the list of all messages in iteration order.
        ``transaction_count`` is the summed length of
        ``update['daily_perf']['transactions']`` over every message that
        contains a ``'daily_perf'`` entry.
    """
    output = []
    transaction_count = 0

    # start the simulation: iterating the zipline is what drives it.
    for update in zipline:
        output.append(update)
        # Count transactions as each message arrives, before the next one
        # is produced.
        if 'daily_perf' in update:
            transaction_count += len(update['daily_perf']['transactions'])

    return output, transaction_count
|
|
|
|
|
|
def assert_single_position(test, zipline):
    """
    Run ``zipline`` to completion and assert it ends holding one position.

    Checks the run against ``test.zipline_test_config``:

    * the total number of transactions equals ``config['order_count']``;
    * there are at least two messages;
    * the second-to-last message (the final day's results) lists exactly
      one position;
    * that position's ``'sid'`` equals ``config['sid']``.
    """
    output, transaction_count = drain_zipline(test, zipline)

    test.assertEqual(
        test.zipline_test_config['order_count'],
        transaction_count
    )

    # the final message is the risk report, the second to
    # last is the final day's results. Positions is a list of
    # dicts. Fail with a readable message instead of an IndexError
    # when fewer messages than that were produced.
    test.assertTrue(
        len(output) >= 2,
        "Expected at least 2 messages (final day's results and risk "
        "report), got %d" % len(output)
    )
    closing_positions = output[-2]['daily_perf']['positions']

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )

    sid = test.zipline_test_config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )
|
|
|
|
|
|
class ExceptionSource(object):
    """
    Source stand-in whose iteration always fails.

    It is its own iterator. Requesting the first item raises
    ``ZeroDivisionError``.
    """

    def __init__(self):
        pass

    def get_hash(self):
        return "ExceptionSource"

    def __iter__(self):
        return self

    def next(self):
        # Deliberate division by zero: the failure is a ZeroDivisionError.
        5 / 0

    # Python 3 looks up ``__next__``; ``next`` alone only satisfies Python 2.
    __next__ = next
|
|
|
|
|
|
class ExceptionTransform(object):
    """
    Transform stand-in whose ``update`` always raises ``AssertionError``.
    """

    def __init__(self):
        pass

    def get_hash(self):
        return "ExceptionTransform"

    def update(self, event):
        # An explicit raise, unlike ``assert False``, is not stripped when
        # Python runs with -O, so this always fails the same way.
        raise AssertionError("An assertion message")
|