Files
catalyst/zipline/utils/test_utils.py
T
fawce 2c7355a0dc Refactoring of TradingEnvironment to isolate the global state: index symbol and exchange timezone. Parameters that define the simulation (start, end, and capital base) were put in a new class, SimulationParameters.
Global state for the financial simulation environment is accessed through the
zipline.finance.trading module, which now contains a module variable:
environment.

Parameters are passed into an algorithm as a keyword argument, sim_params.
SimulationParameters creates a trading day index for the test period that
can be used to find trading days, calculate distance between trading days,
and other common operations. The sim params index is just selected from the
global state.

================

Details:

    - adding delorean to the requirements.
    - made index symbol a parameter for loading the benchmark data. changed
    messagepack storage to be symbol specific.
    - ported risk, performance, algorithm, transforms, batch transforms
    and associated tests to use simulation parameters and global environment
    - factory and sim factory use global state and sim params
    - factory method parameter names now reflect the class expected
2013-02-18 10:24:32 -05:00

136 lines
3.4 KiB
Python

from datetime import datetime
import blist
from zipline.utils.date_utils import EPOCH
from itertools import izip
from logbook import FileHandler
def setup_logger(test, path='test.log'):
    """
    Create a logbook FileHandler for `path`, remember it on the test as
    `test.log_handler`, and activate it with push_application().

    Pair every call with teardown_logger(test).
    """
    handler = FileHandler(path)
    test.log_handler = handler
    handler.push_application()
def teardown_logger(test):
    """
    Undo setup_logger: pop `test.log_handler` off the application handler
    stack and close it.

    The handler is closed even when pop_application() raises, so a failed
    pop does not leave the underlying log file open. The exception from
    pop_application() still propagates to the caller.
    """
    handler = test.log_handler
    try:
        handler.pop_application()
    finally:
        handler.close()
def check_list(test, a, b, label):
    """
    Assert that two lists (or blists) are element-wise equal.

    test  -- object providing the unittest assert* methods
    a, b  -- the sequences to compare
    label -- path of this list inside the structure being checked; used
             to build failure messages

    Each pair of elements is compared recursively with check(), using
    label + "[i]" as the path of the i-th element.
    """
    list_types = (list, blist.blist)
    test.assertTrue(isinstance(a, list_types))
    test.assertTrue(isinstance(b, list_types))
    # izip stops at the shorter input, so without this check a list that
    # is a strict prefix of the other would be reported as equal.
    test.assertEqual(len(a), len(b), "mismatched lengths at: " + label)
    for i, (a_val, b_val) in enumerate(izip(a, b)):
        check(test, a_val, b_val, label + "[" + str(i) + "]")
def check_dict(test, a, b, label):
    """
    Assert that every key of dict `a` is present in dict `b` with an
    equal value.

    test  -- object providing the unittest assert* methods
    a, b  -- the dicts to compare
    label -- path of this dict inside the structure being checked; used
             to build failure messages

    Values are compared recursively with check(), using "label.key" as
    the path of each value.

    NOTE: only the keys of `a` are walked; keys that exist solely in `b`
    are not examined.
    """
    test.assertTrue(isinstance(a, dict))
    test.assertTrue(isinstance(b, dict))
    for key in a:
        # Format rather than concatenate so that non-string keys (ints,
        # tuples, ...) yield a readable path instead of a TypeError that
        # would hide the real comparison result.
        path = "%s.%s" % (label, key)
        test.assertTrue(key in b, "missing key at: " + path)
        check(test, a[key], b[key], path)
def check_datetime(test, a, b, label):
    """
    Assert that `a` and `b` are both datetimes and that EPOCH() maps
    them to the same value. `label` is the path reported on mismatch.
    """
    for value in (a, b):
        test.assertTrue(isinstance(value, datetime))
    a_epoch = EPOCH(a)
    b_epoch = EPOCH(b)
    test.assertEqual(a_epoch, b_epoch, "mismatched dates " + label)
def check(test, a, b, label=None):
    """
    Recursively compare `a` against `b`.

    Dicts, lists/blists and datetimes are dispatched (based on the type
    of `a`) to check_dict, check_list and check_datetime respectively;
    anything else is compared directly with assertEqual. `label` is the
    path used in failure messages and falls back to '<root>' when empty
    or omitted.
    """
    path = label if label else '<root>'

    if isinstance(a, dict):
        check_dict(test, a, b, path)
        return

    if isinstance(a, (list, blist.blist)):
        check_list(test, a, b, path)
        return

    if isinstance(a, datetime):
        check_datetime(test, a, b, path)
        return

    # leaf value: strings, ints, floats, ...
    test.assertEqual(a, b, "mismatch on path: " + path)
def drain_zipline(test, zipline):
    """
    Exhaust `zipline` and collect everything it yields.

    test    -- unused; kept so the signature matches the other helpers
    zipline -- iterable of update messages; each message must support
               the `in` operator and item access

    Returns a tuple (output, transaction_count) where `output` is the
    list of all messages in the order they were yielded and
    `transaction_count` is the total number of entries found under
    update['daily_perf']['transactions'] across the messages that have
    a 'daily_perf' key.
    """
    # iterating the zipline is what runs the simulation
    output = list(zipline)
    transaction_count = sum(
        len(update['daily_perf']['transactions'])
        for update in output
        if 'daily_perf' in update
    )
    return output, transaction_count
def assert_single_position(test, zipline):
    """
    Drain `zipline` and assert that it produced the expected number of
    transactions and finished holding exactly one position, in the sid
    given by test.zipline_test_config['sid'].

    The expected transaction count is read from
    test.zipline_test_config['expected_transactions'] when that key is
    present, otherwise from test.zipline_test_config['order_count'].

    Returns the (output, transaction_count) tuple from drain_zipline.
    """
    output, transaction_count = drain_zipline(test, zipline)
    config = test.zipline_test_config

    if 'expected_transactions' in config:
        expected_key = 'expected_transactions'
    else:
        expected_key = 'order_count'
    test.assertEqual(config[expected_key], transaction_count)

    # The last message is the risk report; the one before it carries the
    # final day's results, whose positions entry is a list of dicts.
    final_day = output[-2]
    closing_positions = final_day['daily_perf']['positions']

    test.assertEqual(
        len(closing_positions), 1, "Portfolio should have one position."
    )

    sid = config['sid']
    held_sid = closing_positions[0]['sid']
    test.assertEqual(
        held_sid, sid, "Portfolio should have one position in " + str(sid)
    )

    return output, transaction_count
class ExceptionSource(object):
    """
    Iterator-style source whose next() always fails with a
    ZeroDivisionError.
    """

    def __init__(self):
        """Nothing to set up; the source carries no state."""

    def get_hash(self):
        """Return the fixed identifier "ExceptionSource"."""
        return "ExceptionSource"

    def __iter__(self):
        """The source is its own iterator."""
        return self

    def next(self):
        """Always raise ZeroDivisionError."""
        # deliberate failure: dividing by zero is the whole point
        5 / 0
class ExceptionTransform(object):
    """
    Transform whose update() always fails with an AssertionError
    carrying the message "An assertion message".
    """

    def __init__(self):
        pass

    def get_hash(self):
        """Return the fixed identifier "ExceptionTransform"."""
        return "ExceptionTransform"

    def update(self, event):
        """
        Always raise AssertionError; `event` is ignored.
        """
        # Raise explicitly instead of `assert False, ...`: assert
        # statements are removed under `python -O`, which would turn
        # this into a silent no-op.
        raise AssertionError("An assertion message")