mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 19:47:13 +08:00
132 lines
3.0 KiB
Python
132 lines
3.0 KiB
Python
from contextlib import contextmanager
|
|
from logbook import FileHandler
|
|
from zipline.finance.blotter import ORDER_STATUS
|
|
|
|
from six import itervalues
|
|
|
|
import pandas as pd
|
|
|
|
|
|
def to_utc(time_str):
    """Interpret ``time_str`` as US/Eastern wall time and return it in UTC.

    :param time_str: anything ``pd.Timestamp`` accepts as a timestamp value.
    :returns: a timezone-aware ``pd.Timestamp`` converted to UTC.
    """
    eastern_time = pd.Timestamp(time_str, tz='US/Eastern')
    return eastern_time.tz_convert('UTC')
|
|
|
|
|
|
def setup_logger(test, path='test.log'):
    """Create a logbook ``FileHandler`` for ``path`` and push it application-wide.

    The handler is stored on ``test.log_handler`` so that ``teardown_logger``
    can later pop and close it.

    :param test: object on which the handler is stored as ``log_handler``.
    :param path: file the handler writes to; defaults to ``'test.log'``.
    """
    handler = FileHandler(path)
    test.log_handler = handler
    handler.push_application()
|
|
|
|
|
|
def teardown_logger(test):
    """Pop and close the log handler stored on ``test.log_handler``.

    The handler is closed even when ``pop_application()`` raises, so a
    failed pop does not leave the handler open. Any exception from the pop
    still propagates to the caller.

    :param test: object carrying the handler as ``log_handler``.
    """
    handler = test.log_handler
    try:
        handler.pop_application()
    finally:
        # Always release the handler, whether or not the pop succeeded.
        handler.close()
|
|
|
|
|
|
def drain_zipline(test, zipline):
    """Exhaust ``zipline`` and tally the transactions it reported.

    :param test: unused here; kept so the signature stays the same for callers.
    :param zipline: iterable of update messages. Iterating it runs the
        simulation.
    :returns: ``(output, transaction_count)`` where ``output`` is the list of
        every update in order and ``transaction_count`` is the total length
        of ``update['daily_perf']['transactions']`` over the updates that
        contain a ``'daily_perf'`` entry.
    """
    # Iterating the object is what drives the simulation to completion.
    output = list(zipline)

    transaction_count = sum(
        len(update['daily_perf']['transactions'])
        for update in output
        if 'daily_perf' in update
    )

    return output, transaction_count
|
|
|
|
|
|
def assert_single_position(test, zipline):
    """Run ``zipline`` to completion and assert it ends with one filled position.

    Checks, in order:

    * the transaction count equals ``expected_transactions`` from
      ``test.zipline_test_config`` when that key is present, otherwise
      ``order_count``;
    * every order seen in the daily performance updates ended up with
      status ``ORDER_STATUS.FILLED``;
    * the closing positions hold exactly one entry, and its ``sid`` equals
      ``test.zipline_test_config['sid']``.

    :param test: test case providing ``assertEqual`` and
        ``zipline_test_config``.
    :param zipline: iterable of update messages, drained via
        ``drain_zipline``.
    :returns: ``(output, transaction_count)`` as produced by
        ``drain_zipline``.
    """
    output, transaction_count = drain_zipline(test, zipline)

    config = test.zipline_test_config
    if 'expected_transactions' in config:
        expected_count = config['expected_transactions']
    else:
        expected_count = config['order_count']
    test.assertEqual(expected_count, transaction_count)

    # the final message is the risk report, the second to
    # last is the final day's results. Positions is a list of
    # dicts.
    closing_positions = output[-2]['daily_perf']['positions']

    # confirm that all orders were filled.
    # iterate over the output updates, overwriting
    # orders when they are updated. Then check the status on all.
    orders_by_id = {}
    for update in output:
        if 'daily_perf' not in update:
            continue
        daily_perf = update['daily_perf']
        if 'orders' not in daily_perf:
            continue
        for order in daily_perf['orders']:
            orders_by_id[order['id']] = order

    for order in itervalues(orders_by_id):
        test.assertEqual(
            order['status'],
            ORDER_STATUS.FILLED,
            # Name the offending order so a failure is diagnosable.
            "Order %s should be filled." % (order['id'],)
        )

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )

    sid = config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )

    return output, transaction_count
|
|
|
|
|
|
class ExceptionSource(object):
    """Iterator whose every advance raises ``ZeroDivisionError``."""

    def __init__(self):
        """No state to initialise."""

    def get_hash(self):
        """Return the fixed identifier string for this source."""
        return "ExceptionSource"

    def __iter__(self):
        """Return ``self``; the object is its own iterator."""
        return self

    def __next__(self):
        """Fail with ``ZeroDivisionError`` instead of producing an item."""
        # Deliberate division by zero: this source exists to blow up.
        5 / 0

    # Python 2 spelling of the iterator protocol; same failing behaviour.
    next = __next__
|
|
|
|
|
|
class ExceptionTransform(object):
    """Transform stub whose ``update`` always raises ``AssertionError``."""

    def __init__(self):
        """Set ``window_length`` to 1."""
        self.window_length = 1

    def get_hash(self):
        """Return the fixed identifier string for this transform."""
        return "ExceptionTransform"

    def update(self, event):
        """Always fail with ``AssertionError("An assertion message")``.

        :param event: ignored.
        :raises AssertionError: unconditionally.
        """
        # Raise explicitly rather than via ``assert False``: assert statements
        # are removed under ``python -O``, which would make this a silent no-op.
        raise AssertionError("An assertion message")
|
|
|
|
|
|
@contextmanager
def nullctx():
    """
    Context manager that does nothing on entry or exit and yields ``None``.

    Handy for choosing a context manager conditionally in a single line,
    e.g.:

        with SomeContextManager() if some_expr else nullctx():
            do_stuff()
    """
    yield
|