mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 21:09:25 +08:00
6e8a4b8144
- Add an `ascending=True` keyword to `rank()`. - Add `top(N)` and `bottom(N)` methods to Factor. These return Filters that pass the top and bottom N elements each day. - Add a slightly faster path for rank(method='ordinal'). I had originally thought the fast path was 2-3x faster because I had my benchmark data axes flipped. The actual speedup is only 5-10%, which means it probably wasn't worth the effort to Cythonize...but we have a slightly faster version now so we might as well use it. - Refactor test_filter and test_factor to make it easier to implement and test transformations on factors. These tests now subclass BaseFFCTestCase, which provides facilities for passing a dict of terms and an "initial_workspace", the values for which are used by SimpleFFCEngine rather than needing to manually manage the inputs and outputs of each term.
340 lines
9.3 KiB
Python
340 lines
9.3 KiB
Python
from contextlib import contextmanager
|
|
from itertools import (
|
|
product,
|
|
)
|
|
from logbook import FileHandler
|
|
from mock import patch
|
|
from numpy.testing import assert_array_equal
|
|
import operator
|
|
from zipline.finance.blotter import ORDER_STATUS
|
|
from zipline.utils import security_list
|
|
from six import (
|
|
itervalues,
|
|
)
|
|
from six.moves import filter
|
|
|
|
import os
|
|
import pandas as pd
|
|
import shutil
|
|
import tempfile
|
|
|
|
# Timezone-aware (UTC) Timestamp for integer value 0; used as the reference
# point when converting timestamps to elapsed seconds in str_to_seconds.
EPOCH = pd.Timestamp(0, tz='UTC')
|
|
|
|
|
|
def seconds_to_timestamp(seconds):
    """
    Build a UTC-localized pd.Timestamp from a count of seconds.

    Parameters
    ----------
    seconds : int or float
        Value interpreted by pandas with ``unit='s'``.

    Returns
    -------
    timestamp : pd.Timestamp
        Timezone-aware Timestamp in UTC.
    """
    utc_timestamp = pd.Timestamp(seconds, unit='s', tz='UTC')
    return utc_timestamp
|
|
|
|
|
|
def to_utc(time_str):
    """
    Interpret `time_str` as a US/Eastern wall-clock time and return the
    equivalent instant as a UTC pd.Timestamp.
    """
    eastern = pd.Timestamp(time_str, tz='US/Eastern')
    as_utc = eastern.tz_convert('UTC')
    return as_utc
|
|
|
|
|
|
def str_to_seconds(s):
    """
    Convert a pandas-intelligible string to (integer) seconds elapsed since
    EPOCH, constructing the Timestamp with ``tz='UTC'``.

    >>> from pandas import Timestamp
    >>> (Timestamp('2014-01-01') - Timestamp(0)).total_seconds()
    1388534400.0
    >>> str_to_seconds('2014-01-01')
    1388534400
    """
    elapsed = pd.Timestamp(s, tz='UTC') - EPOCH
    # total_seconds() yields a float; int() truncates it to whole seconds.
    whole_seconds = int(elapsed.total_seconds())
    return whole_seconds
|
|
|
|
|
|
def setup_logger(test, path='test.log'):
    """
    Attach a FileHandler writing to `path` to `test` as ``test.log_handler``
    and push it as the application-wide handler.

    Pair with `teardown_logger` to pop and close the handler.
    """
    handler = FileHandler(path)
    test.log_handler = handler
    handler.push_application()
|
|
|
|
|
|
def teardown_logger(test):
    """
    Pop ``test.log_handler`` from the application handler stack, then close
    it.
    """
    handler = test.log_handler
    handler.pop_application()
    handler.close()
|
|
|
|
|
|
def drain_zipline(test, zipline):
    """
    Exhaust `zipline`, collecting every update it yields.

    Parameters
    ----------
    test : object
        Unused by this function; kept so the signature stays compatible
        with existing callers.
    zipline : iterable
        Iterable of update mappings. Updates containing a 'daily_perf' key
        are expected to hold a sized 'transactions' entry under it.

    Returns
    -------
    output : list
        Every update yielded by `zipline`, in order.
    transaction_count : int
        Sum of ``len(update['daily_perf']['transactions'])`` over all
        updates that contain 'daily_perf'.
    """
    output = []
    transaction_count = 0
    # Iterating is what actually drives the simulation.
    for update in zipline:
        output.append(update)
        if 'daily_perf' in update:
            transaction_count += len(update['daily_perf']['transactions'])

    return output, transaction_count
|
|
|
|
|
|
def assert_single_position(test, zipline):
    """
    Drain `zipline` and assert that it ended holding exactly one position.

    Checks, in order, using the assertion methods on `test`:

    1. The number of transactions equals
       ``test.zipline_test_config['expected_transactions']`` when that key
       is present, otherwise ``test.zipline_test_config['order_count']``.
    2. The latest version of every order seen in the daily performance
       updates has status ``ORDER_STATUS.FILLED``.
    3. The closing positions contain exactly one entry, whose 'sid' equals
       ``test.zipline_test_config['sid']``.

    Returns
    -------
    output : list
        All updates yielded by `zipline`.
    transaction_count : int
        Total number of transactions across daily performance updates.
    """
    output, transaction_count = drain_zipline(test, zipline)

    config = test.zipline_test_config
    if 'expected_transactions' in config:
        expected_count = config['expected_transactions']
    else:
        expected_count = config['order_count']
    test.assertEqual(expected_count, transaction_count)

    # The last message is the risk report; the one before it carries the
    # final day's results. Positions is a list of dicts.
    final_day = output[-2]
    closing_positions = final_day['daily_perf']['positions']

    # Confirm that all orders were filled: walk the updates, letting later
    # versions of an order overwrite earlier ones, then check each status.
    latest_orders = {}
    for message in output:
        if 'daily_perf' not in message:
            continue
        daily = message['daily_perf']
        if 'orders' not in daily:
            continue
        for order in daily['orders']:
            latest_orders[order['id']] = order

    for order in itervalues(latest_orders):
        test.assertEqual(order['status'], ORDER_STATUS.FILLED, "")

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position.",
    )

    sid = config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid),
    )

    return output, transaction_count
|
|
|
|
|
|
class ExceptionSource(object):
    """
    Iterator whose every attempt to produce a value raises
    ZeroDivisionError.

    Supports both the Python 2 (`next`) and Python 3 (`__next__`) iterator
    protocols.
    """

    def get_hash(self):
        """Return the fixed identifier string for this source."""
        return "ExceptionSource"

    def __iter__(self):
        """Return self; the object is its own iterator."""
        return self

    def __next__(self):
        """Always fail by dividing by zero."""
        5 / 0

    # Python 2 spelling of the iterator protocol.
    next = __next__
|
|
|
|
|
|
@contextmanager
def security_list_copy():
    """
    Context manager that points `security_list` at a throwaway copy of its
    data directory.

    Each entry of ``security_list.SECURITY_LISTS_DIR`` is copied with
    `shutil.copytree` into a fresh temporary directory. While the context
    is active, ``security_list.SECURITY_LISTS_DIR`` is patched to that
    directory and ``security_list.using_copy`` is patched to True. The
    temporary directory is removed on exit, ignoring removal errors.
    """
    source_root = security_list.SECURITY_LISTS_DIR
    scratch_root = tempfile.mkdtemp()
    try:
        for entry in os.listdir(source_root):
            copy_from = os.path.join(source_root, entry)
            copy_to = os.path.join(scratch_root, entry)
            shutil.copytree(copy_from, copy_to)

        redirect_dir = patch.object(
            security_list, 'SECURITY_LISTS_DIR', scratch_root,
        )
        # `using_copy` is created on demand; add_security_data checks for
        # it to confirm it is running against the copy.
        flag_copy = patch.object(
            security_list, 'using_copy', True, create=True,
        )
        with redirect_dir, flag_copy:
            yield
    finally:
        shutil.rmtree(scratch_root, ignore_errors=True)
|
|
|
|
|
|
def add_security_data(adds, deletes):
    """
    Write "add" and "delete" symbol files into the leveraged ETF list
    directory under ``security_list.SECURITY_LISTS_DIR``.

    Parameters
    ----------
    adds : iterable of str
        Symbols written, one per line, to the "add" file.
    deletes : iterable of str
        Symbols written, one per line, to the "delete" file.

    Raises
    ------
    Exception
        If `security_list` has no ``using_copy`` attribute, i.e. the call
        is not made inside a `security_list_copy` context.
    """
    if not hasattr(security_list, 'using_copy'):
        raise Exception('add_security_data must be used within '
                        'security_list_copy context')

    directory = os.path.join(
        security_list.SECURITY_LISTS_DIR,
        "leveraged_etf_list/20150127/20150125"
    )
    if not os.path.exists(directory):
        os.makedirs(directory)

    # The "delete" file is written first, then the "add" file.
    for filename, symbols in (("delete", deletes), ("add", adds)):
        with open(os.path.join(directory, filename), 'w') as f:
            for sym in symbols:
                f.write(sym)
                f.write('\n')
|
|
|
|
|
|
def all_pairs_matching_predicate(values, pred):
    """
    Lazily produce every ordered pair (v0, v1) drawn from `values` for
    which `pred(v0, v1)` is truthy.

    Parameters
    ----------
    values : iterable
    pred : function
        Two-argument callable applied to each candidate pair.

    Returns
    -------
    pairs_iterator : iterator
        Iterator yielding pairs matching `pred`, in the order produced by
        ``itertools.product(values, repeat=2)``.

    Examples
    --------
    >>> from zipline.utils.test_utils import all_pairs_matching_predicate
    >>> from operator import eq, lt
    >>> list(all_pairs_matching_predicate(range(5), eq))
    [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
    >>> list(all_pairs_matching_predicate("abcd", lt))
    [('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'c'), ('b', 'd'), ('c', 'd')]
    """
    candidates = product(values, repeat=2)
    return (pair for pair in candidates if pred(*pair))
|
|
|
|
|
|
def product_upper_triangle(values, include_diagonal=False):
    """
    Return an iterator over pairs, (v0, v1), drawn from values.

    Parameters
    ----------
    values : iterable
    include_diagonal : bool, optional
        When True, pairs satisfy v0 <= v1; when False (the default), pairs
        satisfy v0 < v1.
    """
    if include_diagonal:
        comparison = operator.le
    else:
        comparison = operator.lt
    return all_pairs_matching_predicate(values, comparison)
|
|
|
|
|
|
def all_subindices(index):
    """
    Return a generator over every slice ``index[start:stop]`` with
    ``0 <= start < stop <= len(index)``.
    """
    # Bounds run to len(index) inclusive so slices can reach the last item.
    bounds = product_upper_triangle(range(len(index) + 1))
    return (index[lo:hi] for lo, hi in bounds)
|
|
|
|
|
|
def make_rotating_asset_info(num_assets,
                             first_start,
                             frequency,
                             periods_between_starts,
                             asset_lifetime):
    """
    Create a DataFrame representing lifetimes of assets that are constantly
    rotating in and out of existence.

    Parameters
    ----------
    num_assets : int
        How many assets to create.
    first_start : pd.Timestamp
        The start date for the first asset.
    frequency : pd.tseries.offsets.Offset (e.g. trading_day)
        Frequency used to interpret next two arguments.
        NOTE(review): earlier docs also listed `str`, but the value is
        multiplied by integers and added to `first_start` here -- confirm
        whether plain strings are really supported.
    periods_between_starts : int
        Create a new asset every `frequency` * `periods_between_starts`.
    asset_lifetime : int
        Each asset exists for `frequency` * `asset_lifetime`.

    Returns
    -------
    info : pd.DataFrame
        DataFrame with columns 'sid', 'symbol', 'asset_type', 'start_date',
        'end_date' and 'exchange', one row per asset.
    """
    sids = range(num_assets)
    first_letter = ord('A')
    symbols = [chr(first_letter + offset) for offset in range(num_assets)]
    asset_types = ['equity'] * num_assets

    # A new asset starts every `periods_between_starts` periods.
    start_dates = pd.date_range(
        first_start,
        freq=(periods_between_starts * frequency),
        periods=num_assets,
    )
    # Each asset ends `asset_lifetime` periods after it starts.
    end_dates = pd.date_range(
        first_start + (asset_lifetime * frequency),
        freq=(periods_between_starts * frequency),
        periods=num_assets,
    )

    columns = {
        'sid': sids,
        'symbol': symbols,
        'asset_type': asset_types,
        'start_date': start_dates,
        'end_date': end_dates,
        'exchange': 'TEST',
    }
    return pd.DataFrame(columns)
|
|
|
|
|
|
def make_simple_asset_info(assets, start_date, end_date, symbols=None):
    """
    Create a DataFrame representing assets that exist for the full duration
    between `start_date` and `end_date`.

    Parameters
    ----------
    assets : array-like
        Values for the 'sid' column; must support `len`.
    start_date : object
        Value repeated in every row of the 'start_date' column.
        NOTE(review): earlier docs gave the type as pd.DatetimeIndex; a
        single pd.Timestamp looks more likely -- confirm against callers.
    end_date : object
        Value repeated in every row of the 'end_date' column (same caveat
        as `start_date`).
    symbols : list, optional
        Symbols to use for the assets.
        If not provided, symbols are generated as consecutive characters
        starting at 'A'.

    Returns
    -------
    info : pd.DataFrame
        DataFrame with columns 'sid', 'symbol', 'asset_type', 'start_date',
        'end_date' and 'exchange', one row per asset.
    """
    count = len(assets)
    if symbols is None:
        first_letter = ord('A')
        symbols = [chr(first_letter + offset) for offset in range(count)]

    columns = {
        'sid': assets,
        'symbol': symbols,
        'asset_type': ['equity'] * count,
        'start_date': [start_date] * count,
        'end_date': [end_date] * count,
        'exchange': 'TEST',
    }
    return pd.DataFrame(columns)
|
|
|
|
|
|
def check_arrays(left, right, err_msg='', verbose=True):
    """
    Wrapper around np.assert_array_equal that also verifies that both
    inputs are of exactly the same type.

    Parameters
    ----------
    left, right : array-like
        Objects to compare. They must share the exact same type (no
        subclass leniency) before their contents are compared.
    err_msg : str, optional
        Message forwarded to `assert_array_equal` for failure output.
    verbose : bool, optional
        Forwarded to `assert_array_equal`; when True the conflicting values
        are included in the failure message.

    Raises
    ------
    AssertionError
        If the types differ or the arrays are not equal.

    See Also
    --------
    np.assert_array_equal
    """
    if type(left) is not type(right):
        raise AssertionError("%s != %s" % (type(left), type(right)))
    # Honor the caller's `verbose` choice rather than forcing it on.
    return assert_array_equal(left, right, err_msg=err_msg, verbose=verbose)
|
|
|
|
|
|
class UnexpectedAttributeAccess(Exception):
    """Raised when an attribute is read from an object that must not be touched."""
|
|
|
|
|
|
class ExplodingObject(object):
    """
    Sentinel object that raises UnexpectedAttributeAccess on any attribute
    access.

    Pass it where a real object is expected in order to verify that a
    function/method call never touches that argument.
    """

    def __getattribute__(self, name):
        """Refuse the lookup, reporting the attribute name requested."""
        violation = UnexpectedAttributeAccess(name)
        raise violation
|