mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 13:13:41 +08:00
2c7355a0dc
Global state for the financial simulation environment is now accessed through
the zipline.finance.trading module, which contains a module-level variable
named environment.
Parameters are passed into an algorithm as a keyword argument, sim_params.
SimulationParameters creates a trading day index for the test period that
can be used to find trading days, calculate distance between trading days,
and other common operations. The sim params index is just selected from the
global state.
================
Details:
- Added delorean to the requirements.
- Made the index symbol a parameter for loading the benchmark data. Changed
MessagePack storage to be symbol-specific.
- Ported risk, performance, algorithm, transforms, batch transforms,
and associated tests to use simulation parameters and the global environment.
- Factory and sim factory now use global state and sim params.
- Factory method parameter names now reflect the expected class.
136 lines
3.4 KiB
Python
136 lines
3.4 KiB
Python
from datetime import datetime
|
|
import blist
|
|
from zipline.utils.date_utils import EPOCH
|
|
from itertools import izip
|
|
from logbook import FileHandler
|
|
|
|
|
|
def setup_logger(test, path='test.log'):
    """Route log output for *test* into a file.

    Creates a logbook ``FileHandler`` writing to *path*, stores it on
    ``test.log_handler`` (so ``teardown_logger`` can find it later) and
    registers it with ``push_application()``.

    :param test: the test case object; gains a ``log_handler`` attribute.
    :param path: file the handler writes to (default ``'test.log'``).
    """
    handler = FileHandler(path)
    test.log_handler = handler
    handler.push_application()
|
|
|
|
|
|
def teardown_logger(test):
    """Undo ``setup_logger``: unregister and close ``test.log_handler``.

    The handler is popped with ``pop_application()`` and then closed.
    ``close()`` runs in a ``finally`` clause, so the underlying file is
    released even when ``pop_application()`` raises (for example because
    another handler was pushed on top of it and never popped). The
    exception from ``pop_application()`` still propagates to the caller.

    :param test: the test case object previously passed to ``setup_logger``.
    """
    handler = test.log_handler
    try:
        handler.pop_application()
    finally:
        # Always release the file handle, even if unregistering failed.
        handler.close()
|
|
|
|
|
|
def check_list(test, a, b, label):
    """Assert that two lists (``list`` or ``blist.blist``) match element-wise.

    Both arguments must be list/blist instances of the same length. Each pair
    of elements is then compared recursively with ``check``, using a path
    label of the form ``label[i]``.

    :param test: test case providing ``assertTrue``/``assertEqual``.
    :param a: expected sequence.
    :param b: actual sequence.
    :param label: path prefix used in failure messages.
    """
    test.assertTrue(isinstance(a, (list, blist.blist)))
    test.assertTrue(isinstance(b, (list, blist.blist)))
    # izip stops at the shorter input; without this check a list with
    # extra (or missing) trailing elements would silently compare equal.
    test.assertEqual(len(a), len(b), "mismatched list lengths at: " + label)
    for i, (a_val, b_val) in enumerate(izip(a, b)):
        check(test, a_val, b_val, label + "[" + str(i) + "]")
|
|
|
|
|
|
def check_dict(test, a, b, label):
    """Assert that two dicts have the same keys and recursively equal values.

    Every key of *a* must be present in *b*, and its values are compared with
    ``check`` using a path label of the form ``label.key``. After that, *b*
    must not contain any key that *a* lacks.

    :param test: test case providing ``assertTrue``/``assertFalse``.
    :param a: expected dict.
    :param b: actual dict.
    :param label: path prefix used in failure messages.
    """
    test.assertTrue(isinstance(a, dict))
    test.assertTrue(isinstance(b, dict))

    for key in a:
        # "%s" formatting (rather than "+") keeps non-string keys such as
        # ints or tuples from raising TypeError while building the label.
        path = "%s.%s" % (label, key)
        test.assertTrue(key in b, "missing key at: " + path)
        check(test, a[key], b[key], path)

    # Iterating only over a's keys cannot notice keys that exist solely in b.
    extra = [key for key in b if key not in a]
    test.assertFalse(extra, "unexpected keys at: %s: %r" % (label, extra))
|
|
|
|
|
|
def check_datetime(test, a, b, label):
    """Assert that *a* and *b* are datetimes that agree after ``EPOCH``.

    Both values must be ``datetime`` instances. They are compared via the
    project's ``EPOCH`` helper rather than directly (presumably a conversion
    to an epoch-based number; see ``zipline.utils.date_utils``).

    :param test: test case providing ``assertTrue``/``assertEqual``.
    :param label: path used in the failure message.
    """
    for value in (a, b):
        test.assertTrue(isinstance(value, datetime))
    expected, actual = EPOCH(a), EPOCH(b)
    test.assertEqual(expected, actual, "mismatched dates " + label)
|
|
|
|
|
|
def check(test, a, b, label=None):
    """
    Check equality for arbitrarily nested dicts and lists that terminate
    in types that allow direct comparisons (string, ints, floats, datetimes).

    The comparison strategy is chosen from the type of *a*: dicts go to
    ``check_dict``, lists/blists to ``check_list``, datetimes to
    ``check_datetime``, and anything else to a plain ``assertEqual``.
    *label* is the path reported in failure messages; a falsy value becomes
    ``'<root>'``.
    """
    label = label or '<root>'

    # Order matters: the first matching type wins.
    dispatch = (
        (dict, check_dict),
        ((list, blist.blist), check_list),
        (datetime, check_datetime),
    )
    for kinds, checker in dispatch:
        if isinstance(a, kinds):
            checker(test, a, b, label)
            return

    test.assertEqual(a, b, "mismatch on path: " + label)
|
|
|
|
|
|
def drain_zipline(test, zipline):
    """Iterate *zipline* to exhaustion, collecting every message it yields.

    For each message that contains a ``'daily_perf'`` entry, the length of
    ``update['daily_perf']['transactions']`` is added to a running total.

    :param test: the calling test case; not used here, kept for interface
        compatibility with existing callers.
    :param zipline: any iterable of dict-like messages.
    :returns: ``(output, transaction_count)``, where ``output`` is the list
        of all messages in yield order.
    """
    output = []
    transaction_count = 0

    # Iterating the zipline is what drives (starts) the simulation.
    for update in zipline:
        output.append(update)
        if 'daily_perf' in update:
            transaction_count += len(update['daily_perf']['transactions'])

    return output, transaction_count
|
|
|
|
|
|
def assert_single_position(test, zipline):
    """Run *zipline* and assert it ends holding exactly one position.

    Reads its expectations from ``test.zipline_test_config``:

    * ``'expected_transactions'`` if present, otherwise ``'order_count'``:
      the total number of transactions seen across ``daily_perf`` messages;
    * ``'sid'``: the sid of the single closing position.

    :returns: the ``(output, transaction_count)`` pair from ``drain_zipline``.
    """
    output, transaction_count = drain_zipline(test, zipline)
    config = test.zipline_test_config

    if 'expected_transactions' in config:
        expected_count = config['expected_transactions']
    else:
        expected_count = config['order_count']
    test.assertEqual(expected_count, transaction_count)

    # The final message is the risk report and the one before it holds the
    # final day's results; its positions are a list of dicts.
    final_day = output[-2]
    closing_positions = final_day['daily_perf']['positions']

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )

    sid = config['sid']
    only_position = closing_positions[0]
    test.assertEqual(
        only_position['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )

    return output, transaction_count
|
|
|
|
|
|
class ExceptionSource(object):
    """An iterable source whose every ``next`` raises ZeroDivisionError.

    Presumably used by tests to check how a zipline handles an exception
    raised from a data source.
    """

    def __init__(self):
        pass

    def get_hash(self):
        """Return the fixed identifier ``"ExceptionSource"``."""
        return "ExceptionSource"

    def __iter__(self):
        return self

    def next(self):
        # Deliberate division by zero: pulling any item raises
        # ZeroDivisionError.
        5 / 0

    # Python 3's iterator protocol calls __next__, not next. Without this
    # alias, iter()/for would fail with "not an iterator" instead of the
    # intended ZeroDivisionError.
    __next__ = next
|
|
|
|
|
|
class ExceptionTransform(object):
    """A transform whose ``update`` always raises AssertionError.

    Presumably used by tests to check how a zipline handles an exception
    raised inside a transform.
    """

    def __init__(self):
        pass

    def get_hash(self):
        """Return the fixed identifier ``"ExceptionTransform"``."""
        return "ExceptionTransform"

    def update(self, event):
        # Use an explicit raise rather than ``assert False``: assert
        # statements are stripped under ``python -O``, which would make this
        # silently return None instead of failing.
        raise AssertionError("An assertion message")
|