Files
catalyst/zipline/utils/test_utils.py
T
2015-09-16 01:28:16 -04:00

340 lines
9.3 KiB
Python

from contextlib import contextmanager
from itertools import (
product,
)
from logbook import FileHandler
from mock import patch
from numpy.testing import assert_array_equal
import operator
from zipline.finance.blotter import ORDER_STATUS
from zipline.utils import security_list
from six import (
itervalues,
)
from six.moves import filter
import os
import pandas as pd
import shutil
import tempfile
# The Unix epoch (1970-01-01 00:00:00) as a timezone-aware UTC timestamp.
EPOCH = pd.Timestamp(0).tz_localize('UTC')
def seconds_to_timestamp(seconds):
    """
    Build a timezone-aware UTC timestamp that lies ``seconds`` seconds after
    the Unix epoch.
    """
    as_utc = pd.Timestamp(seconds, unit='s', tz='UTC')
    return as_utc
def to_utc(time_str):
    """
    Interpret ``time_str`` as a US/Eastern wall-clock time and return the
    same instant as a UTC timestamp.
    """
    eastern = pd.Timestamp(time_str, tz='US/Eastern')
    return eastern.tz_convert('UTC')
def str_to_seconds(s):
    """
    Convert a pandas-intelligible string, interpreted as UTC, into the whole
    number of seconds elapsed since the Unix epoch.

    The fractional part of the elapsed seconds is discarded by ``int``.

    >>> from pandas import Timestamp
    >>> (Timestamp('2014-01-01') - Timestamp(0)).total_seconds()
    1388534400.0
    >>> str_to_seconds('2014-01-01')
    1388534400
    """
    elapsed = pd.Timestamp(s, tz='UTC') - EPOCH
    return int(elapsed.total_seconds())
def setup_logger(test, path='test.log'):
    """
    Create a logbook ``FileHandler`` writing to ``path``, store it on
    ``test`` as ``test.log_handler`` and push it application-wide.

    Reverse with ``teardown_logger``.
    """
    handler = FileHandler(path)
    test.log_handler = handler
    handler.push_application()
def teardown_logger(test):
    """
    Undo ``setup_logger``: pop ``test.log_handler`` off the application
    handler stack, then close it.
    """
    handler = test.log_handler
    handler.pop_application()
    handler.close()
def drain_zipline(test, zipline):
    """
    Exhaust ``zipline`` and collect every message it produces.

    Parameters
    ----------
    test : object
        The calling test case. Not used; kept so existing callers that pass
        it continue to work.
    zipline : iterable of dict
        Simulation messages.

    Returns
    -------
    output : list
        Every message yielded by ``zipline``, in order.
    transaction_count : int
        Sum of ``len(update['daily_perf']['transactions'])`` over every
        message that has a ``'daily_perf'`` key.
    """
    # Run the simulation to completion.
    output = list(zipline)
    transaction_count = sum(
        len(update['daily_perf']['transactions'])
        for update in output
        if 'daily_perf' in update
    )
    return output, transaction_count
def assert_single_position(test, zipline):
    """
    Run ``zipline`` to completion and assert that it closed with exactly one
    position, in ``test.zipline_test_config['sid']``, and that every order
    it reported ended up filled.

    Parameters
    ----------
    test : object
        The calling test case. Must provide ``assertEqual`` and a
        ``zipline_test_config`` dict containing ``'sid'`` plus either
        ``'expected_transactions'`` or ``'order_count'``.
    zipline : iterable of dict
        Simulation messages, consumed via ``drain_zipline``.

    Returns
    -------
    output : list
        Every message produced by ``zipline``.
    transaction_count : int
        Total number of transactions reported across all messages.
    """
    output, transaction_count = drain_zipline(test, zipline)
    config = test.zipline_test_config

    # An explicit 'expected_transactions' wins; otherwise the transaction
    # count must equal the configured number of orders.
    if 'expected_transactions' in config:
        expected_transactions = config['expected_transactions']
    else:
        expected_transactions = config['order_count']
    test.assertEqual(expected_transactions, transaction_count)

    # The final message is the risk report, the second to last is the
    # final day's results. Positions is a list of dicts.
    closing_positions = output[-2]['daily_perf']['positions']

    # Confirm that all orders were filled. An order can be reported in many
    # updates, so keep only the most recent snapshot of each (keyed by id)
    # and check the status of those.
    orders_by_id = {}
    for update in output:
        if 'daily_perf' in update and 'orders' in update['daily_perf']:
            for order in update['daily_perf']['orders']:
                orders_by_id[order['id']] = order

    for order in itervalues(orders_by_id):
        # Name the offending order so a failure is actionable.
        test.assertEqual(
            order['status'],
            ORDER_STATUS.FILLED,
            "Order %r was not filled." % (order['id'],),
        )

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )
    sid = config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )
    return output, transaction_count
class ExceptionSource(object):
    """
    Self-iterating test source that fails with ``ZeroDivisionError`` as soon
    as a value is requested.

    Both the Python 3 (``__next__``) and Python 2 (``next``) iterator methods
    are provided.
    """

    def __init__(self):
        # Nothing to initialize.
        pass

    def get_hash(self):
        return "ExceptionSource"

    def __iter__(self):
        return self

    def __next__(self):
        # Deliberately divide by zero so iteration raises.
        5 / 0

    # Python 2 spelling of the iterator protocol.
    next = __next__
@contextmanager
def security_list_copy():
    """
    Context manager that redirects ``security_list`` to a scratch copy of its
    data directory.

    Each entry of ``security_list.SECURITY_LISTS_DIR`` is copied with
    ``shutil.copytree`` into a new temporary directory. For the duration of
    the ``with`` block, ``security_list.SECURITY_LISTS_DIR`` is patched to
    point at that copy and ``security_list.using_copy`` is patched to True
    (created if absent). The temporary directory is removed on exit; errors
    during removal are ignored.
    """
    source_root = security_list.SECURITY_LISTS_DIR
    scratch_root = tempfile.mkdtemp()
    try:
        for entry in os.listdir(source_root):
            shutil.copytree(
                os.path.join(source_root, entry),
                os.path.join(scratch_root, entry),
            )
        with patch.object(security_list, 'SECURITY_LISTS_DIR', scratch_root):
            with patch.object(security_list, 'using_copy', True, create=True):
                yield
    finally:
        # ignore_errors=True: cleanup is best-effort.
        shutil.rmtree(scratch_root, True)
def add_security_data(adds, deletes):
    """
    Write ``delete`` and ``add`` symbol files under
    ``<SECURITY_LISTS_DIR>/leveraged_etf_list/20150127/20150125``.

    Each file gets one symbol per line, and any existing file is overwritten.
    The directory is created if missing. Must be called inside
    ``security_list_copy`` so that the real security-list data is not
    modified.

    Parameters
    ----------
    adds : iterable of str
        Symbols written to the ``add`` file.
    deletes : iterable of str
        Symbols written to the ``delete`` file.

    Raises
    ------
    Exception
        If ``security_list`` has no ``using_copy`` attribute, i.e. when not
        running under ``security_list_copy``.
    """
    if not hasattr(security_list, 'using_copy'):
        raise Exception('add_security_data must be used within '
                        'security_list_copy context')
    target_dir = os.path.join(
        security_list.SECURITY_LISTS_DIR,
        "leveraged_etf_list/20150127/20150125"
    )
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    # The delete file is written before the add file.
    for filename, symbols in (("delete", deletes), ("add", adds)):
        with open(os.path.join(target_dir, filename), 'w') as out:
            for sym in symbols:
                out.write(sym)
                out.write('\n')
def all_pairs_matching_predicate(values, pred):
    """
    Lazily iterate over the ordered pairs ``(v0, v1)`` from the Cartesian
    square of ``values`` for which ``pred(v0, v1)`` is truthy.

    Pairs are produced in ``itertools.product`` order.

    Parameters
    ----------
    values : iterable
    pred : function
        Binary predicate called as ``pred(v0, v1)``.

    Returns
    -------
    pairs_iterator : generator
        Generator yielding pairs matching `pred`.

    Examples
    --------
    >>> from zipline.utils.test_utils import all_pairs_matching_predicate
    >>> from operator import eq, lt
    >>> list(all_pairs_matching_predicate(range(5), eq))
    [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
    >>> list(all_pairs_matching_predicate("abcd", lt))
    [('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'c'), ('b', 'd'), ('c', 'd')]
    """
    candidates = product(values, repeat=2)
    return (pair for pair in candidates if pred(*pair))
def product_upper_triangle(values, include_diagonal=False):
    """
    Iterate over ordered pairs ``(v0, v1)`` drawn from ``values``.

    With ``include_diagonal=True`` only pairs with ``v0 <= v1`` are produced.
    Otherwise only pairs with the strict ``v0 < v1`` are produced.
    """
    if include_diagonal:
        keep = operator.le
    else:
        keep = operator.lt
    return all_pairs_matching_predicate(values, keep)
def all_subindices(index):
    """
    Lazily yield every non-empty contiguous slice ``index[start:stop]`` of a
    pandas Index (or anything supporting ``len`` and slicing).

    Slices are ordered by ``start`` and then by ``stop``.
    """
    n = len(index)
    return (
        index[start:stop]
        for start in range(n + 1)
        for stop in range(start + 1, n + 1)
    )
def make_rotating_asset_info(num_assets,
                             first_start,
                             frequency,
                             periods_between_starts,
                             asset_lifetime):
    """
    Create a DataFrame representing lifetimes of assets that are constantly
    rotating in and out of existence.

    Parameters
    ----------
    num_assets : int
        How many assets to create.
    first_start : pd.Timestamp
        The start date for the first asset.
    frequency : str or pd.tseries.offsets.DateOffset (e.g. trading_day)
        Frequency used to interpret the next two arguments. Strings such as
        'D' are converted with ``pandas.tseries.frequencies.to_offset``;
        offset objects are used as-is.
    periods_between_starts : int
        Create a new asset every ``frequency`` * ``periods_between_starts``.
    asset_lifetime : int
        Each asset's end date is ``frequency`` * ``asset_lifetime`` after its
        start date.

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets, with columns ``sid``
        (0 .. num_assets - 1), ``symbol``, ``asset_type``, ``start_date``,
        ``end_date`` and ``exchange``.

    Notes
    -----
    Symbols are consecutive characters starting at 'A', so more than 26
    assets yield non-letter symbols (e.g. '[').
    """
    # Normalize the frequency first. With a raw string,
    # ``periods_between_starts * 'D'`` is string repetition ('DD...') and
    # ``first_start + 'D'`` raises TypeError. Offsets pass through unchanged.
    from pandas.tseries.frequencies import to_offset
    offset = to_offset(frequency)
    step = periods_between_starts * offset
    return pd.DataFrame(
        {
            'sid': range(num_assets),
            'symbol': [chr(ord('A') + i) for i in range(num_assets)],
            'asset_type': ['equity'] * num_assets,
            # Start a new asset every `periods_between_starts` periods.
            'start_date': pd.date_range(
                first_start,
                freq=step,
                periods=num_assets,
            ),
            # Each asset lasts for `asset_lifetime` periods.
            'end_date': pd.date_range(
                first_start + (asset_lifetime * offset),
                freq=step,
                periods=num_assets,
            ),
            'exchange': 'TEST',
        }
    )
def make_simple_asset_info(assets, start_date, end_date, symbols=None):
    """
    Build an asset-info DataFrame in which every asset lives for the whole
    span from ``start_date`` to ``end_date``.

    Parameters
    ----------
    assets : array-like
        Values used for the ``sid`` column.
    start_date : pd.Timestamp
    end_date : pd.Timestamp
    symbols : list, optional
        Symbols to use for the assets. When omitted, 'A', 'B', ... are
        generated, one per asset.

    Returns
    -------
    info : pd.DataFrame
        One row per asset with columns ``sid``, ``symbol``, ``asset_type``,
        ``start_date``, ``end_date`` and ``exchange``.
    """
    count = len(assets)
    if symbols is None:
        symbols = [chr(ord('A') + offset) for offset in range(count)]

    def _repeat(value):
        # One copy of ``value`` per asset.
        return [value] * count

    return pd.DataFrame({
        'sid': assets,
        'symbol': symbols,
        'asset_type': _repeat('equity'),
        'start_date': _repeat(start_date),
        'end_date': _repeat(end_date),
        'exchange': 'TEST',
    })
def check_arrays(left, right, err_msg='', verbose=True):
    """
    Wrapper around ``numpy.testing.assert_array_equal`` that first verifies
    that ``left`` and ``right`` are of exactly the same type.

    Parameters
    ----------
    left, right : array-like
        Values to compare.
    err_msg : str, optional
        Forwarded to ``assert_array_equal``.
    verbose : bool, optional
        Forwarded to ``assert_array_equal``. When True, the conflicting
        values are included in the failure message.

    Raises
    ------
    AssertionError
        If ``type(left) != type(right)`` or the arrays are not equal.

    See Also
    --------
    numpy.testing.assert_array_equal
    """
    # Exact type comparison, not isinstance: a subclass and its base class
    # are reported as a mismatch.
    if type(left) != type(right):
        raise AssertionError("%s != %s" % (type(left), type(right)))
    # Forward the caller's ``verbose``; it was previously hard-coded to True.
    return assert_array_equal(left, right, err_msg=err_msg, verbose=verbose)
class UnexpectedAttributeAccess(Exception):
    """
    Raised by ``ExplodingObject`` whenever one of its attributes is looked
    up; the exception's sole argument is the attribute name.
    """
class ExplodingObject(object):
    """
    Test double that raises ``UnexpectedAttributeAccess`` on any attribute
    access made through ``__getattribute__``.

    Pass one in place of a real object to verify that a function/method
    call never touches it; the raised exception names the attribute that
    was requested.
    """

    def __getattribute__(self, name):
        raise UnexpectedAttributeAccess(name)