mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-27 19:30:28 +08:00
1383 lines
40 KiB
Python
1383 lines
40 KiB
Python
from abc import ABCMeta, abstractmethod, abstractproperty
|
|
from contextlib import contextmanager
|
|
from functools import wraps
|
|
import gzip
|
|
from inspect import getargspec
|
|
from itertools import (
|
|
combinations,
|
|
count,
|
|
product,
|
|
)
|
|
import operator
|
|
import os
|
|
from os.path import abspath, dirname, join, realpath
|
|
import shutil
|
|
import tempfile
|
|
|
|
from logbook import TestHandler
|
|
from mock import patch
|
|
from nose.tools import nottest
|
|
from numpy.testing import assert_allclose, assert_array_equal
|
|
import pandas as pd
|
|
from six import itervalues, iteritems, with_metaclass
|
|
from six.moves import filter, map
|
|
from sqlalchemy import create_engine
|
|
from testfixtures import TempDirectory
|
|
from toolz import concat
|
|
|
|
from zipline.assets import AssetFinder, AssetDBWriter
|
|
from zipline.assets.synthetic import make_simple_equity_info
|
|
from zipline.data.data_portal import DataPortal
|
|
from zipline.data.minute_bars import (
|
|
BcolzMinuteBarReader,
|
|
BcolzMinuteBarWriter,
|
|
US_EQUITIES_MINUTES_PER_DAY
|
|
)
|
|
from zipline.data.us_equity_pricing import (
|
|
BcolzDailyBarReader,
|
|
BcolzDailyBarWriter,
|
|
SQLiteAdjustmentWriter,
|
|
)
|
|
from zipline.finance.trading import TradingEnvironment
|
|
from zipline.finance.order import ORDER_STATUS
|
|
from zipline.lib.labelarray import LabelArray
|
|
from zipline.pipeline.engine import SimplePipelineEngine
|
|
from zipline.pipeline.loaders.testing import make_seeded_random_loader
|
|
from zipline.utils import security_list
|
|
from zipline.utils.input_validation import expect_dimensions
|
|
from zipline.utils.sentinel import sentinel
|
|
from zipline.utils.tradingcalendar import trading_days
|
|
from zipline.utils.calendars import default_nyse_schedule
|
|
import numpy as np
|
|
from numpy import float64
|
|
|
|
|
|
# Reference instant for epoch arithmetic in this module: the timestamp pandas
# builds from the integer 0, localized to UTC.
EPOCH = pd.Timestamp(0, tz='UTC')
|
|
|
|
|
|
def seconds_to_timestamp(seconds):
    """
    Build a UTC-localized ``pd.Timestamp`` from a count of seconds.

    Parameters
    ----------
    seconds : int or float
        Value handed to ``pd.Timestamp`` with ``unit='s'``.

    Returns
    -------
    timestamp : pd.Timestamp
        The resulting timestamp, with ``tz='UTC'``.
    """
    timestamp = pd.Timestamp(seconds, unit='s', tz='UTC')
    return timestamp
|
|
|
|
|
|
def to_utc(time_str):
    """
    Interpret ``time_str`` as US/Eastern wall-clock time and convert to UTC.

    Parameters
    ----------
    time_str : str
        A pandas-parseable date/time string.

    Returns
    -------
    timestamp : pd.Timestamp
        The same instant expressed in UTC.
    """
    eastern = pd.Timestamp(time_str, tz='US/Eastern')
    return eastern.tz_convert('UTC')
|
|
|
|
|
|
def str_to_seconds(s):
    """
    Convert a pandas-intelligible string to (integer) seconds since the
    module-level ``EPOCH``, interpreting the string as UTC.

    >>> from pandas import Timestamp
    >>> (Timestamp('2014-01-01') - Timestamp(0)).total_seconds()
    1388534400.0
    >>> str_to_seconds('2014-01-01')
    1388534400
    """
    elapsed = pd.Timestamp(s, tz='UTC') - EPOCH
    # ``total_seconds`` is a float; callers get a truncated integer.
    return int(elapsed.total_seconds())
|
|
|
|
|
|
def drain_zipline(test, zipline):
    """
    Exhaust a simulation iterable, collecting every message it emits.

    Parameters
    ----------
    test : object
        Not used by this function; kept so existing call sites keep working.
    zipline : iterable
        Iterable of update messages. Each message must support ``in`` and
        item access; those containing a ``'daily_perf'`` key must hold a
        sized ``'transactions'`` entry under it.

    Returns
    -------
    output : list
        Every message, in emission order.
    transaction_count : int
        Total number of transactions across all ``'daily_perf'`` messages.
    """
    output = []
    transaction_count = 0
    # Iterating is what drives the simulation forward.
    for update in zipline:
        output.append(update)
        if 'daily_perf' in update:
            transaction_count += len(update['daily_perf']['transactions'])

    return output, transaction_count
|
|
|
|
|
|
def check_algo_results(test,
                       results,
                       expected_transactions_count=None,
                       expected_order_count=None,
                       expected_positions_count=None,
                       sid=None):
    """
    Assert transaction and order counts found in algorithm results.

    Parameters
    ----------
    test : unittest.TestCase
        Supplies ``assertEqual``.
    results : mapping
        ``results["transactions"]`` and ``results["orders"]`` are read as
        iterables of iterables (they are flattened before counting).
    expected_transactions_count : int, optional
        When given, the flattened transaction count must equal it.
    expected_order_count : int, optional
        When given, the number of distinct order ids must equal it.
    expected_positions_count : int, optional
        Not supported: any non-None value raises ``NotImplementedError``.
    sid : object, optional
        Not used by this function.
    """
    if expected_transactions_count is not None:
        transactions = flatten_list(results["transactions"])
        test.assertEqual(expected_transactions_count, len(transactions))

    if expected_positions_count is not None:
        raise NotImplementedError

    if expected_order_count is None:
        return

    # De-duplicate orders on id, because orders are put back into perf
    # packets whenever a txn is filled.
    unique_order_ids = {
        order['id'] for order in flatten_list(results["orders"])
    }
    test.assertEqual(expected_order_count, len(unique_order_ids))
|
|
|
|
|
|
def flatten_list(list):
    """
    Concatenate the items of every sub-iterable into one new list.

    Only one level of nesting is removed. The parameter name shadows the
    ``list`` builtin, so the builtin is unavailable inside this function.
    """
    flattened = []
    for sublist in list:
        flattened.extend(sublist)
    return flattened
|
|
|
|
|
|
def assert_single_position(test, zipline):
    """
    Drain ``zipline`` and assert it ended holding exactly one position.

    Checks, in order: the transaction count (against
    ``test.zipline_test_config['expected_transactions']`` when present,
    otherwise ``['order_count']``), that every order seen ended FILLED, that
    the closing positions contain one entry, and that the entry's sid equals
    ``test.zipline_test_config['sid']``.

    Returns
    -------
    output, transaction_count
        Exactly what ``drain_zipline`` returned.
    """
    output, transaction_count = drain_zipline(test, zipline)

    config = test.zipline_test_config
    if 'expected_transactions' in config:
        expected_count = config['expected_transactions']
    else:
        expected_count = config['order_count']
    test.assertEqual(expected_count, transaction_count)

    # The final message is the risk report, the second to last is the final
    # day's results. Positions is a list of dicts.
    closing_positions = output[-2]['daily_perf']['positions']

    # Confirm that all orders were filled: walk the updates, letting later
    # copies of an order overwrite earlier ones, then check each final state.
    latest_orders = {}
    for update in output:
        if 'daily_perf' not in update:
            continue
        daily_perf = update['daily_perf']
        if 'orders' not in daily_perf:
            continue
        for order in daily_perf['orders']:
            latest_orders[order['id']] = order

    for order in itervalues(latest_orders):
        test.assertEqual(order['status'], ORDER_STATUS.FILLED, "")

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )

    sid = test.zipline_test_config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )

    return output, transaction_count
|
|
|
|
|
|
class ExceptionSource(object):
    """
    Iterator whose every advance raises ``ZeroDivisionError``.

    It is its own iterator, and supports both the Python 2 (``next``) and
    Python 3 (``__next__``) spellings of the iterator protocol.
    """

    def get_hash(self):
        return "ExceptionSource"

    def __iter__(self):
        return self

    def __next__(self):
        # Deliberate division by zero: the exception is the whole point.
        5 / 0

    # Python 2 looks up ``next`` rather than ``__next__``.
    next = __next__
|
|
|
|
|
|
@contextmanager
def security_list_copy():
    """
    Context manager that points ``security_list`` at a scratch copy.

    Every entry of ``security_list.SECURITY_LISTS_DIR`` is copied with
    ``shutil.copytree`` into a fresh temporary directory. Inside the context
    ``SECURITY_LISTS_DIR`` is patched to that directory and a ``using_copy``
    attribute is set to True on the module. The temporary directory is
    removed on exit, ignoring removal errors.
    """
    source_root = security_list.SECURITY_LISTS_DIR
    scratch_root = tempfile.mkdtemp()
    try:
        for entry in os.listdir(source_root):
            shutil.copytree(
                os.path.join(source_root, entry),
                os.path.join(scratch_root, entry),
            )
        with patch.object(security_list, 'SECURITY_LISTS_DIR', scratch_root):
            # ``create=True``: the attribute may not exist on the module yet.
            with patch.object(security_list, 'using_copy', True,
                              create=True):
                yield
    finally:
        shutil.rmtree(scratch_root, ignore_errors=True)
|
|
|
|
|
|
def add_security_data(adds, deletes):
    """
    Write ``add`` and ``delete`` symbol files into the security-list tree.

    The files are written under
    ``leveraged_etf_list/20150127/20150125`` inside
    ``security_list.SECURITY_LISTS_DIR``, one symbol per line, ``delete``
    first and then ``add``. The directory is created when missing.

    Parameters
    ----------
    adds : iterable[str]
        Symbols written to the ``add`` file.
    deletes : iterable[str]
        Symbols written to the ``delete`` file.

    Raises
    ------
    Exception
        If ``security_list`` has no ``using_copy`` attribute, i.e. the call
        is not made inside ``security_list_copy``.
    """
    if not hasattr(security_list, 'using_copy'):
        raise Exception('add_security_data must be used within '
                        'security_list_copy context')

    directory = os.path.join(
        security_list.SECURITY_LISTS_DIR,
        "leveraged_etf_list/20150127/20150125"
    )
    if not os.path.exists(directory):
        os.makedirs(directory)

    for filename, symbols in (("delete", deletes), ("add", adds)):
        with open(os.path.join(directory, filename), 'w') as f:
            for sym in symbols:
                f.write(sym)
                f.write('\n')
|
|
|
|
|
|
def all_pairs_matching_predicate(values, pred):
    """
    Return an iterator of all pairs, (v0, v1) from values such that

    `pred(v0, v1) == True`

    Parameters
    ----------
    values : iterable
    pred : function

    Returns
    -------
    pairs_iterator : generator
        Generator yielding pairs matching `pred`.

    Examples
    --------
    >>> from zipline.testing import all_pairs_matching_predicate
    >>> from operator import eq, lt
    >>> list(all_pairs_matching_predicate(range(5), eq))
    [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
    >>> list(all_pairs_matching_predicate("abcd", lt))
    [('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'c'), ('b', 'd'), ('c', 'd')]
    """
    # The full cross product is enumerated lazily; only matching pairs are
    # passed on.
    candidate_pairs = product(values, repeat=2)
    return (pair for pair in candidate_pairs if pred(*pair))
|
|
|
|
|
|
def product_upper_triangle(values, include_diagonal=False):
    """
    Return an iterator over pairs, (v0, v1), drawn from values.

    If `include_diagonal` is True, returns all pairs such that v0 <= v1.
    If `include_diagonal` is False, returns all pairs such that v0 < v1.
    """
    if include_diagonal:
        comparison = operator.le
    else:
        comparison = operator.lt
    return all_pairs_matching_predicate(values, comparison)
|
|
|
|
|
|
def all_subindices(index):
    """
    Return all valid sub-indices of a pandas Index.

    Each result is the slice ``index[start:stop]`` for a pair of positions
    with ``start < stop``, taken from ``range(len(index) + 1)``.
    """
    bounds = product_upper_triangle(range(len(index) + 1))
    return (index[start:stop] for start, stop in bounds)
|
|
|
|
|
|
def chrange(start, stop):
    """
    Construct an iterable of length-1 strings beginning with `start` and ending
    with `stop`.

    Parameters
    ----------
    start : str
        The first character.
    stop : str
        The last character (inclusive).

    Returns
    -------
    chars: iterable[str]
        List of characters from start through stop, in code-point order.

    Example
    -------
    >>> chrange('A', 'C')
    ['A', 'B', 'C']
    """
    first_code = ord(start)
    last_code = ord(stop)
    return [chr(code) for code in range(first_code, last_code + 1)]
|
|
|
|
|
|
def make_trade_data_for_asset_info(dates,
                                   asset_info,
                                   price_start,
                                   price_step_by_date,
                                   price_step_by_sid,
                                   volume_start,
                                   volume_step_by_date,
                                   volume_step_by_sid,
                                   frequency,
                                   writer=None):
    """
    Convert the asset info dataframe into a dataframe of trade data for each
    sid, and write to the writer if provided. Write zeros (not NaNs) for
    locations where assets did not exist. Return a dict of the dataframes,
    keyed by sid.

    Parameters
    ----------
    dates : pd.DatetimeIndex
        Index of every produced frame.
    asset_info : pd.DataFrame
        Indexed by sid, with ``start_date`` and ``end_date`` columns.
    price_start, price_step_by_date, price_step_by_sid : number
        Price at (first date, first sid) and the increments applied per date
        and per sid position.
    volume_start, volume_step_by_date, volume_step_by_sid : number
        Same as above, for volume.
    frequency : object
        Not used by this function.
    writer : object, optional
        When truthy, ``writer.write_sid(sid, frame)`` is called per sid.

    Returns
    -------
    trade_data : dict
        Maps sid to a frame with open/high/low/close (all the same price)
        and volume columns.
    """
    trade_data = {}
    sids = asset_info.index

    price_sid_deltas = np.arange(len(sids), dtype=float64) * price_step_by_sid
    price_date_deltas = (np.arange(len(dates), dtype=float64) *
                         price_step_by_date)
    prices = (price_sid_deltas + price_date_deltas[:, None]) + price_start

    volume_sid_deltas = np.arange(len(sids)) * volume_step_by_sid
    volume_date_deltas = np.arange(len(dates)) * volume_step_by_date
    volumes = (volume_sid_deltas + volume_date_deltas[:, None]) + volume_start

    # Normalize so that we still generate non-zero values on the minutes of
    # an asset's last trading day. Computed once; it does not depend on sid.
    sessions = dates.normalize()

    for j, sid in enumerate(sids):
        start_date, end_date = asset_info.loc[sid, ['start_date', 'end_date']]
        # Rows whose session falls outside [start_date, end_date].
        inactive = ~((sessions >= start_date) & (sessions <= end_date))
        prices[inactive, j] = 0
        volumes[inactive, j] = 0

        df = pd.DataFrame(
            {
                "open": prices[:, j],
                "high": prices[:, j],
                "low": prices[:, j],
                "close": prices[:, j],
                "volume": volumes[:, j],
            },
            index=dates,
        )

        if writer:
            writer.write_sid(sid, df)

        trade_data[sid] = df

    return trade_data
|
|
|
|
|
|
def check_allclose(actual,
                   desired,
                   rtol=1e-07,
                   atol=0,
                   err_msg='',
                   verbose=True):
    """
    Wrapper around np.testing.assert_allclose that first verifies that both
    inputs have the same type.

    Raises
    ------
    AssertionError
        If the types differ, or if ``assert_allclose`` rejects the values.

    See Also
    --------
    np.assert_allclose
    """
    actual_type, desired_type = type(actual), type(desired)
    if actual_type != desired_type:
        raise AssertionError("%s != %s" % (actual_type, desired_type))

    return assert_allclose(
        actual,
        desired,
        rtol=rtol,
        atol=atol,
        err_msg=err_msg,
        verbose=verbose,
    )
|
|
|
|
|
|
def check_arrays(x, y, err_msg='', verbose=True, check_dtypes=True):
    """
    Wrapper around np.testing.assert_array_equal that also verifies that the
    inputs have the same type and, optionally, the same dtype.

    Parameters
    ----------
    x, y : array-like with a ``dtype`` attribute
        The arrays to compare.
    err_msg : str, optional
        Forwarded to ``assert_array_equal``.
    verbose : bool, optional
        Forwarded to ``assert_array_equal``.
    check_dtypes : bool, optional
        Whether to require ``x.dtype == y.dtype``. Defaults to True.

    Raises
    ------
    AssertionError
        On a type mismatch, a dtype mismatch (when ``check_dtypes``), or
        unequal contents.

    See Also
    --------
    np.assert_array_equal
    """
    # Raised explicitly rather than via ``assert`` so the checks survive
    # ``python -O``.
    if type(x) != type(y):
        raise AssertionError(
            "{x} != {y}".format(x=type(x), y=type(y))
        )
    if check_dtypes and x.dtype != y.dtype:
        raise AssertionError("{x.dtype} != {y.dtype}".format(x=x, y=y))

    if isinstance(x, LabelArray):
        # Check that both arrays have missing values in the same locations...
        assert_array_equal(
            x.is_missing(),
            y.is_missing(),
            err_msg=err_msg,
            verbose=verbose,
        )
        # ...then check the actual values as well.
        x = x.as_string_array()
        y = y.as_string_array()

    return assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)
|
|
|
|
|
|
class UnexpectedAttributeAccess(Exception):
    """Raised when an attribute is read from an object that must stay untouched."""
|
|
|
|
|
|
class ExplodingObject(object):
    """
    Object that will raise an exception on any attribute access.

    Useful for verifying that an object is never touched during a
    function/method call.
    """

    def __getattribute__(self, attr):
        # Overriding ``__getattribute__`` (not ``__getattr__``) intercepts
        # every lookup on the instance.
        raise UnexpectedAttributeAccess(attr)
|
|
|
|
|
|
def write_minute_data(trading_schedule, tempdir, minutes, sids):
    """
    Write synthetic minute bars for ``sids`` into ``tempdir`` and return the
    directory path.

    The bars come from ``create_minute_bar_data`` and cover the execution
    days spanned by the first and last entries of ``minutes``.
    """
    days = trading_schedule.execution_days_in_range(minutes[0], minutes[-1])
    destination = tempdir.path
    bars = create_minute_bar_data(minutes, sids)
    write_bcolz_minute_data(trading_schedule, days, destination, bars)
    return tempdir.path
|
|
|
|
|
|
def create_minute_bar_data(minutes, sids):
    """
    Yield ``(sid, frame)`` pairs of synthetic minute bars.

    Every column is a ramp that increases by one per minute. Its starting
    value is a per-column base (open 10, high 15, low 8, close 10,
    volume 100) plus the sid's position in ``sids``.
    """
    bar_count = len(minutes)
    column_bases = (
        ('open', 10),
        ('high', 15),
        ('low', 8),
        ('close', 10),
        ('volume', 100),
    )
    for position, sid in enumerate(sids):
        columns = {}
        for column, base in column_bases:
            columns[column] = np.arange(bar_count) + base + position
        yield sid, pd.DataFrame(columns, index=minutes)
|
|
|
|
|
|
def create_daily_bar_data(trading_days, sids):
    """
    Yield ``(sid, frame)`` pairs of synthetic daily bars.

    Price and volume columns are ramps increasing by one per day, starting at
    a per-column base (open 10, high 15, low 8, close 10, volume 100) plus
    the sid's position in ``sids``. The ``day`` column holds each day's
    ``.value``.
    """
    day_count = len(trading_days)

    def ramp(first, shift):
        # Built through ``range`` exactly as before so dtypes are unchanged.
        return np.array(range(first, first + day_count)) + shift

    for position, sid in enumerate(sids):
        frame = pd.DataFrame(
            {
                "open": ramp(10, position),
                "high": ramp(15, position),
                "low": ramp(8, position),
                "close": ramp(10, position),
                "volume": ramp(100, position),
                "day": [day.value for day in trading_days]
            },
            index=trading_days,
        )
        yield sid, frame
|
|
|
|
|
|
def write_daily_data(tempdir, sim_params, sids):
    """
    Write synthetic daily bars for ``sids`` to ``testdaily.bcolz`` inside
    ``tempdir`` and return that path.

    The bars come from ``create_daily_bar_data`` over
    ``sim_params.trading_days``.
    """
    path = os.path.join(tempdir.path, "testdaily.bcolz")
    writer = BcolzDailyBarWriter(path, sim_params.trading_days)
    bars = create_daily_bar_data(sim_params.trading_days, sids)
    writer.write(bars)
    return path
|
|
|
|
|
|
def create_data_portal(env, tempdir, sim_params, sids, trading_schedule,
                       adjustment_reader=None):
    """
    Build a ``DataPortal`` backed by synthetic bars written into ``tempdir``.

    When ``sim_params.data_frequency`` is ``"daily"`` a daily bar reader is
    used; for any other value minute bars are written for the range
    ``sim_params.first_open`` to ``sim_params.last_close`` and a minute bar
    reader is used. ``adjustment_reader`` is forwarded unchanged.
    """
    if sim_params.data_frequency == "daily":
        daily_reader = BcolzDailyBarReader(
            write_daily_data(tempdir, sim_params, sids)
        )
        return DataPortal(
            env, trading_schedule,
            first_trading_day=daily_reader.first_trading_day,
            equity_daily_reader=daily_reader,
            adjustment_reader=adjustment_reader
        )

    minutes = trading_schedule.execution_minutes_for_days_in_range(
        sim_params.first_open,
        sim_params.last_close
    )
    minute_reader = BcolzMinuteBarReader(
        write_minute_data(trading_schedule, tempdir, minutes, sids)
    )
    return DataPortal(
        env, trading_schedule,
        first_trading_day=minute_reader.first_trading_day,
        equity_minute_reader=minute_reader,
        adjustment_reader=adjustment_reader
    )
|
|
|
|
|
|
def write_bcolz_minute_data(trading_schedule, days, path, data):
    """
    Write ``data`` to ``path`` with a ``BcolzMinuteBarWriter``.

    The writer is configured with the first entry of ``days``, the
    market open/close columns of ``trading_schedule.schedule`` for those
    days, and ``US_EQUITIES_MINUTES_PER_DAY``.
    """
    day_rows = trading_schedule.schedule.loc[days]
    writer = BcolzMinuteBarWriter(
        days[0],
        path,
        day_rows.market_open,
        day_rows.market_close,
        US_EQUITIES_MINUTES_PER_DAY
    )
    writer.write(data)
|
|
|
|
|
|
def create_minute_df_for_asset(trading_schedule,
                               start_dt,
                               end_dt,
                               interval=1,
                               start_val=1,
                               minute_blacklist=None):
    """
    Build a frame of synthetic minute bars between ``start_dt`` and
    ``end_dt``.

    ``close`` counts up from ``start_val``; open/high/low are close +1/+2/-1
    and volume is 100 * close. When ``interval > 1`` the first
    ``interval - 1`` rows of every ``interval``-sized window are zeroed.
    Rows labelled by entries of ``minute_blacklist`` are zeroed as well.
    """
    asset_minutes = trading_schedule.execution_minutes_for_days_in_range(
        start_dt, end_dt
    )
    closes = np.array(range(start_val, start_val + len(asset_minutes)))

    frame = pd.DataFrame(
        {
            "open": closes + 1,
            "high": closes + 2,
            "low": closes - 1,
            "close": closes,
            "volume": 100 * closes,
        },
        index=asset_minutes,
    )

    if interval > 1:
        for window_start in range(0, len(closes), interval):
            # Positional row slice: blanks all but the window's last row.
            frame[window_start:(window_start + interval - 1)] = 0

    if minute_blacklist is not None:
        for minute in minute_blacklist:
            frame.loc[minute] = 0

    return frame
|
|
|
|
|
|
def create_daily_df_for_asset(trading_schedule, start_day, end_day,
                              interval=1):
    """
    Build a frame of synthetic daily bars between ``start_day`` and
    ``end_day``.

    ``close`` counts up from 2; open/high/low are close +1/+2/-1 and volume
    is 100 * close. When ``interval > 1`` only every ``interval``-th row
    (1-based) keeps its values; all other rows are zeroed.

    Parameters
    ----------
    trading_schedule : object
        Must provide ``execution_days_in_range(start_day, end_day)``.
    start_day, end_day : object
        Forwarded to ``execution_days_in_range``.
    interval : int, optional
        Spacing between rows that keep their data. Defaults to 1.

    Returns
    -------
    df : pd.DataFrame
        Indexed by the execution days.
    """
    days = trading_schedule.execution_days_in_range(start_day, end_day)
    days_count = len(days)
    days_arr = np.arange(days_count) + 2

    df = pd.DataFrame(
        {
            "open": days_arr + 1,
            "high": days_arr + 2,
            "low": days_arr - 1,
            "close": days_arr,
            "volume": days_arr * 100,
        },
        index=days,
    )

    if interval > 1:
        # Only keep every 'interval' rows. Assign through the frame's own
        # indexer: chained ``df[col].iloc[i] = 0`` is not guaranteed to
        # write back to ``df``.
        dropped = (np.arange(days_count) + 1) % interval != 0
        df.iloc[dropped] = 0

    return df
|
|
|
|
|
|
def trades_by_sid_to_dfs(trades_by_sid, index):
    """
    Yield ``(sid, frame)`` pairs built from per-sid trade records.

    Each trade is read by key (``open_price``, ``high``, ``low``,
    ``close_price``, ``volume``) and becomes one row of an
    open/high/low/close/volume frame indexed by ``index``.
    """
    # (output column, key read from each trade), in read order.
    field_map = (
        ("open", "open_price"),
        ("high", "high"),
        ("low", "low"),
        ("close", "close_price"),
        ("volume", "volume"),
    )
    for sidint, trades in iteritems(trades_by_sid):
        columns = {}
        for column, _ in field_map:
            columns[column] = []
        # Single pass, so one-shot iterables of trades are still fine.
        for trade in trades:
            for column, key in field_map:
                columns[column].append(trade[key])

        yield sidint, pd.DataFrame(columns, index=index)
|
|
|
|
|
|
def create_data_portal_from_trade_history(env, trading_schedule, tempdir,
                                          sim_params, trades_by_sid):
    """
    Build a ``DataPortal`` whose bar data comes from explicit trade records.

    For ``sim_params.data_frequency == "daily"`` the trades are converted
    with ``trades_by_sid_to_dfs`` (item access on each trade) and written to
    ``testdaily.bcolz`` in ``tempdir``. Otherwise each trade is placed at the
    position of ``trade.dt`` among the execution minutes (attribute access on
    each trade), with prices multiplied by 1000, and minute bars are written
    to ``tempdir``.
    """
    if sim_params.data_frequency == "daily":
        path = os.path.join(tempdir.path, "testdaily.bcolz")
        daily_writer = BcolzDailyBarWriter(path, sim_params.trading_days)
        daily_writer.write(
            trades_by_sid_to_dfs(trades_by_sid, sim_params.trading_days),
        )

        daily_reader = BcolzDailyBarReader(path)
        return DataPortal(
            env, trading_schedule,
            first_trading_day=daily_reader.first_trading_day,
            equity_daily_reader=daily_reader,
        )

    minutes = trading_schedule.execution_minutes_for_days_in_range(
        sim_params.first_open,
        sim_params.last_close
    )
    length = len(minutes)
    assets = {}

    for sidint, trades in iteritems(trades_by_sid):
        # Minutes without a trade stay at zero.
        opens, highs, lows, closes, volumes = (
            np.zeros(length) for _ in range(5)
        )

        for trade in trades:
            # Put each trade in the slot of its minute.
            slot = minutes.searchsorted(trade.dt)
            opens[slot] = trade.open_price * 1000
            highs[slot] = trade.high * 1000
            lows[slot] = trade.low * 1000
            closes[slot] = trade.close_price * 1000
            volumes[slot] = trade.volume

        assets[sidint] = pd.DataFrame({
            "open": opens,
            "high": highs,
            "low": lows,
            "close": closes,
            "volume": volumes,
            "dt": minutes
        }).set_index("dt")

    # NOTE(review): ``assets`` is a dict here, while ``write_minute_data``
    # hands the same helper an iterable of (sid, frame) pairs -- confirm the
    # writer accepts a mapping.
    write_bcolz_minute_data(
        trading_schedule,
        trading_schedule.execution_days_in_range(
            sim_params.first_open,
            sim_params.last_close
        ),
        tempdir.path,
        assets
    )

    minute_reader = BcolzMinuteBarReader(tempdir.path)
    return DataPortal(
        env, trading_schedule,
        first_trading_day=minute_reader.first_trading_day,
        equity_minute_reader=minute_reader,
    )
|
|
|
|
|
|
class FakeDataPortal(DataPortal):
    """
    ``DataPortal`` stand-in that answers with constant values.

    Spot values are 100 for ``"volume"`` and 1.0 for every other field;
    daily history windows are frames filled with 100.
    """

    def __init__(self, env=None, trading_schedule=default_nyse_schedule,
                 first_trading_day=None):
        # A fresh TradingEnvironment is created only when none is supplied.
        env = TradingEnvironment() if env is None else env
        super(FakeDataPortal, self).__init__(env, trading_schedule,
                                             first_trading_day)

    def get_spot_value(self, asset, field, dt, data_frequency):
        return 100 if field == "volume" else 1.0

    def get_history_window(self, assets, end_dt, bar_count, frequency, field,
                           ffill=True):
        """
        Return a ``bar_count`` x ``len(assets)`` frame of 100s for ``"1d"``.

        NOTE(review): the original indentation was not available; this
        assumes the whole body sits under the frequency check, so any other
        frequency returns None -- confirm.
        """
        if frequency != "1d":
            return None

        all_days = self.trading_schedule.all_execution_days
        end_idx = all_days.searchsorted(end_dt)
        days = all_days[(end_idx - bar_count + 1):(end_idx + 1)]

        return pd.DataFrame(
            np.full((bar_count, len(assets)), 100),
            index=days,
            columns=assets
        )
|
|
|
|
|
|
class FetcherDataPortal(DataPortal):
    """
    Mock dataportal that returns fake data for history and non-fetcher
    spot value.
    """

    def __init__(self, env, trading_schedule, first_trading_day=None):
        super(FetcherDataPortal, self).__init__(env, trading_schedule,
                                                first_trading_day)

    def get_spot_value(self, asset, field, dt, data_frequency):
        is_fetcher_field = self._is_extra_source(
            asset, field, self._augmented_sources_map
        )
        if not is_fetcher_field:
            # Not a fetcher field: just return a fixed value.
            return int(asset)

        # Fetcher fields exercise the regular code path.
        return super(FetcherDataPortal, self).get_spot_value(
            asset, field, dt, data_frequency)

    def _get_daily_window_for_sid(self, asset, field, days_in_window,
                                  extra_slot=True):
        window = np.arange(days_in_window, dtype=np.float64)
        return window

    def _get_minute_window_for_asset(self, asset, field, minutes_for_window):
        window = np.arange(minutes_for_window, dtype=np.float64)
        return window
|
|
|
|
|
|
class tmp_assets_db(object):
    """Create a temporary assets sqlite database.
    This is meant to be used as a context manager.

    Parameters
    ----------
    **frames
        The frames to pass to the AssetDBWriter.
        By default this maps equities:
        ('A', 'B', 'C') -> map(ord, 'ABC')

    See Also
    --------
    empty_assets_db
    tmp_asset_finder
    """
    # Distinguishes "argument omitted" from an explicit ``equities=None``.
    _default_equities = sentinel('_default_equities')

    def __init__(self, equities=_default_equities, **frames):
        if equities is self._default_equities:
            equities = make_simple_equity_info(
                [ord(symbol) for symbol in 'ABC'],
                pd.Timestamp(0),
                pd.Timestamp('2015'),
            )

        frames['equities'] = equities
        self._frames = frames
        self._eng = None  # set in enter and exit

    def __enter__(self):
        engine = create_engine('sqlite://')
        self._eng = engine
        AssetDBWriter(engine).write(**self._frames)
        return engine

    def __exit__(self, *excinfo):
        assert self._eng is not None, '_eng was not set in __enter__'
        self._eng.dispose()
        self._eng = None
|
|
|
|
|
|
def empty_assets_db():
    """Context manager for creating an empty assets db.

    Equivalent to ``tmp_assets_db`` with ``equities=None``.

    See Also
    --------
    tmp_assets_db
    """
    manager = tmp_assets_db(equities=None)
    return manager
|
|
|
|
|
|
class tmp_asset_finder(tmp_assets_db):
    """Create a temporary asset finder using an in memory sqlite db.

    Parameters
    ----------
    finder_cls : type, optional
        The type of asset finder to create from the assets db.
    **frames
        Forwarded to ``tmp_assets_db``.

    See Also
    --------
    tmp_assets_db
    """
    def __init__(self, finder_cls=AssetFinder, **frames):
        self._finder_cls = finder_cls
        super(tmp_asset_finder, self).__init__(**frames)

    def __enter__(self):
        # The parent yields the engine; wrap it in the configured finder.
        engine = super(tmp_asset_finder, self).__enter__()
        return self._finder_cls(engine)
|
|
|
|
|
|
def empty_asset_finder():
    """Context manager for creating an empty asset finder.

    Equivalent to ``tmp_asset_finder`` with ``equities=None``.

    See Also
    --------
    empty_assets_db
    tmp_assets_db
    tmp_asset_finder
    """
    manager = tmp_asset_finder(equities=None)
    return manager
|
|
|
|
|
|
class tmp_trading_env(tmp_asset_finder):
    """Create a temporary trading environment.

    Parameters
    ----------
    finder_cls : type, optional
        The type of asset finder to create from the assets db.
    **frames
        Forwarded to ``tmp_assets_db``.

    See Also
    --------
    empty_trading_env
    tmp_asset_finder
    """
    def __enter__(self):
        # The parent yields a finder; its ``engine`` is handed to the
        # environment as the asset db.
        finder = super(tmp_trading_env, self).__enter__()
        return TradingEnvironment(asset_db_path=finder.engine)
|
|
|
|
|
|
def empty_trading_env():
    """Context manager for creating a trading environment with no equities.

    Equivalent to ``tmp_trading_env`` with ``equities=None``.
    """
    manager = tmp_trading_env(equities=None)
    return manager
|
|
|
|
|
|
class SubTestFailures(AssertionError):
    """
    Aggregate of every failure collected by ``subtest``.

    Parameters
    ----------
    *failures : tuple[dict, Exception]
        ``(scope, exception)`` pairs, where ``scope`` maps parameter names
        (or positions) to the values that were in effect.
    """
    def __init__(self, *failures):
        self.failures = failures

    def __str__(self):
        # No trailing comma after the generator expression: newer Python
        # versions reject an unparenthesized genexp followed by a comma.
        return 'failures:\n %s' % '\n '.join(
            '\n '.join((
                ', '.join('%s=%r' % item for item in scope.items()),
                '%s: %s' % (type(exc).__name__, exc),
            )) for scope, exc in self.failures
        )


def subtest(iterator, *_names):
    """
    Construct a subtest in a unittest.

    Consider using ``zipline.testing.parameter_space`` when subtests
    are constructed over a single input or over the cross-product of multiple
    inputs.

    ``subtest`` works by decorating a function as a subtest. The decorated
    function will be run by iterating over the ``iterator`` and *unpacking the
    values into the function. If any of the runs fail, the result will be put
    into a list and the rest of the tests will be run. Finally, if any failed,
    all of the results will be dumped as one failure.

    Parameters
    ----------
    iterator : iterable[iterable]
        The iterator of arguments to pass to the function.
    *_names : iterator[str]
        The names to use for each element of ``iterator``. These will be used
        to print the scope when a test fails. If not provided, it will use the
        integer index of the value as the name.

    Raises
    ------
    SubTestFailures
        From the decorated function, if at least one run raised.

    Examples
    --------

    ::

       class MyTest(TestCase):
           def test_thing(self):
               # Example usage inside another test.
               @subtest(([n] for n in range(100000)), 'n')
               def subtest(n):
                   self.assertEqual(n % 2, 0, 'n was not even')
               subtest()

           @subtest(([n] for n in range(100000)), 'n')
           def test_decorated_function(self, n):
               # Example usage to parameterize an entire function.
               self.assertEqual(n % 2, 1, 'n was not odd')

    Notes
    -----
    We use this when we:

    * Will never want to run each parameter individually.
    * Have a large parameter space we are testing
      (see tests/utils/test_events.py).

    ``nose_parameterized.expand`` will create a test for each parameter
    combination which bloats the test output and makes the travis pages slow.

    We cannot use ``unittest2.TestCase.subTest`` because nose, pytest, and
    nose2 do not support ``addSubTest``.

    See Also
    --------
    zipline.testing.parameter_space
    """
    def dec(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            failures = []
            for scope in iterator:
                scope = tuple(scope)
                try:
                    f(*args + scope, **kwargs)
                except Exception as e:
                    if _names:
                        labelled = dict(zip(_names, scope))
                    else:
                        # Fresh positional labels for every failure; a
                        # shared counter would keep advancing between them.
                        labelled = dict(enumerate(scope))
                    failures.append((labelled, e))
            if failures:
                raise SubTestFailures(*failures)

        return wrapped
    return dec
|
|
|
|
|
|
class MockDailyBarReader(object):
    """Daily bar reader stand-in whose spot price is always 100."""

    def spot_price(self, col, sid, dt):
        # Every column, sid and date gets the same fixed price.
        fixed_price = 100
        return fixed_price
|
|
|
|
|
|
def create_mock_adjustment_data(splits=None, dividends=None, mergers=None):
    """
    Normalize adjustment inputs into DataFrames.

    ``None`` becomes the matching empty frame; anything that is not already
    a ``pd.DataFrame`` is passed to the ``pd.DataFrame`` constructor.

    Returns
    -------
    splits, mergers, dividends : pd.DataFrame
        Note the order differs from the parameter order.
    """
    def as_frame(value):
        return value if isinstance(value, pd.DataFrame) else pd.DataFrame(value)

    splits = (
        create_empty_splits_mergers_frame() if splits is None
        else as_frame(splits)
    )
    mergers = (
        create_empty_splits_mergers_frame() if mergers is None
        else as_frame(mergers)
    )
    dividends = (
        create_empty_dividends_frame() if dividends is None
        else as_frame(dividends)
    )

    return splits, mergers, dividends
|
|
|
|
|
|
def create_mock_adjustments(tempdir, days, splits=None, dividends=None,
                            mergers=None):
    """
    Write an adjustments database named ``test_adjustments.db`` in
    ``tempdir`` and return its path.

    Inputs are normalized by ``create_mock_adjustment_data`` and written with
    a ``SQLiteAdjustmentWriter`` backed by a ``MockDailyBarReader``.
    """
    path = tempdir.getpath("test_adjustments.db")
    writer = SQLiteAdjustmentWriter(path, MockDailyBarReader(), days)
    frames = create_mock_adjustment_data(splits, dividends, mergers)
    writer.write(*frames)
    return path
|
|
|
|
|
|
def assert_timestamp_equal(left, right, compare_nat_equal=True, msg=""):
    """
    Assert that two pandas Timestamp objects are the same.

    Parameters
    ----------
    left, right : pd.Timestamp
        The values to compare.
    compare_nat_equal : bool, optional
        Whether to consider `NaT` values equal. Defaults to True.
    msg : str, optional
        Prefix for the failure message. ``{0}`` and ``{1}`` in it are
        replaced with ``left`` and ``right``.

    Raises
    ------
    AssertionError
        If the values are not equal.
    """
    if compare_nat_equal and left is pd.NaT and right is pd.NaT:
        return
    # Checked inline instead of through ``pd.util.testing.assert_equal``,
    # which is not available in current pandas releases.
    if not left == right:
        raise AssertionError(
            "%s: %r != %r" % (msg.format(left, right), left, right)
        )
|
|
|
|
|
|
def powerset(values):
    """
    Return the power set (i.e., the set of all subsets) of entries in `values`.

    Subsets come out as tuples, ordered by size from the empty tuple up to
    the full length of ``values``.
    """
    subset_sizes = range(len(values) + 1)
    return concat(combinations(values, size) for size in subset_sizes)
|
|
|
|
|
|
def to_series(knowledge_dates, earning_dates):
    """
    Helper for building a Series of datetimes indexed by datetimes.

    Both arguments are converted with ``pd.to_datetime``;
    ``knowledge_dates`` becomes the index and ``earning_dates`` the values.

    This is just for making the test cases more readable.
    """
    index = pd.to_datetime(knowledge_dates)
    values = pd.to_datetime(earning_dates)
    return pd.Series(data=values, index=index)
|
|
|
|
|
|
def gen_calendars(start, stop, critical_dates):
    """
    Generate calendars to use as inputs.

    Yields 1-tuples: one calendar per subset of ``critical_dates`` (the
    daily UTC range from ``start`` to ``stop`` with that subset dropped),
    followed by the slice of ``trading_days`` between ``start`` and ``stop``.
    """
    all_dates = pd.date_range(start, stop, tz='utc')
    for subset in powerset(critical_dates):
        # Have to yield tuples.
        yield (all_dates.drop(list(subset)),)

    # Also test with the trading calendar.
    in_range = trading_days.slice_indexer(start, stop)
    yield (trading_days[in_range],)
|
|
|
|
|
|
@contextmanager
def temp_pipeline_engine(calendar, sids, random_seed, symbols=None):
    """
    A contextManager that yields a SimplePipelineEngine holding a reference to
    an AssetFinder generated via tmp_asset_finder.

    Parameters
    ----------
    calendar : pd.DatetimeIndex
        Calendar to pass to the constructed PipelineEngine.
    sids : iterable[int]
        Sids to use for the temp asset finder.
    random_seed : int
        Integer used to seed instances of SeededRandomLoader.
    symbols : iterable[str], optional
        Symbols for constructed assets. Forwarded to make_simple_equity_info.
    """
    equity_info = make_simple_equity_info(
        sids=sids,
        start_date=calendar[0],
        end_date=calendar[-1],
        symbols=symbols,
    )
    loader = make_seeded_random_loader(random_seed, calendar, sids)

    def get_loader(column):
        # Every column is served by the same seeded loader.
        return loader

    with tmp_asset_finder(equities=equity_info) as finder:
        yield SimplePipelineEngine(get_loader, calendar, finder)
|
|
|
|
|
|
def parameter_space(__fail_fast=False, **params):
    """
    Wrapper around subtest that allows passing keywords mapping names to
    iterables of values.

    The decorated test function will be called with the cross-product of all
    possible inputs.

    Parameters
    ----------
    __fail_fast : bool, optional
        When true, the first failing combination propagates immediately
        instead of being collected through ``subtest``.
    **params
        One iterable of values per argument of the decorated function
        (excluding a leading ``self``).

    Raises
    ------
    AssertionError
        At decoration time, if the function uses ``*args``, ``**kwargs`` or
        defaults, or if ``params`` does not match its arguments exactly.

    Usage
    -----
    >>> from unittest import TestCase
    >>> class SomeTestCase(TestCase):
    ...     @parameter_space(x=[1, 2], y=[2, 3])
    ...     def test_some_func(self, x, y):
    ...         # Will be called with every possible combination of x and y.
    ...         self.assertEqual(somefunc(x, y), expected_result(x, y))

    See Also
    --------
    zipline.testing.subtest
    """
    def decorator(f):
        # ``getargspec`` is gone from newer Python versions; use
        # ``getfullargspec`` where it exists and fall back otherwise.
        try:
            from inspect import getfullargspec
        except ImportError:
            argspec = getargspec(f)
            varkw = argspec.keywords
        else:
            argspec = getfullargspec(f)
            varkw = argspec.varkw

        if argspec.varargs:
            raise AssertionError("parameter_space() doesn't support *args")
        if varkw:
            raise AssertionError("parameter_space() doesn't support **kwargs")
        if argspec.defaults:
            raise AssertionError("parameter_space() doesn't support defaults.")

        # Skip over implicit self.
        argnames = argspec.args
        if argnames and argnames[0] == 'self':
            argnames = argnames[1:]

        extra = set(params) - set(argnames)
        if extra:
            raise AssertionError(
                "Keywords %s supplied to parameter_space() are "
                "not in function signature." % extra
            )

        unspecified = set(argnames) - set(params)
        if unspecified:
            raise AssertionError(
                "Function arguments %s were not "
                "supplied to parameter_space()." % unspecified
            )

        param_sets = product(*(params[name] for name in argnames))

        if __fail_fast:
            @wraps(f)
            def wrapped(self):
                for args in param_sets:
                    f(self, *args)
            return wrapped
        else:
            return subtest(param_sets, *argnames)(f)

    return decorator
|
|
|
|
|
|
def create_empty_dividends_frame():
    """
    Return a zero-row dividends frame with typed columns.

    Columns: four ``datetime64[ns]`` dates (ex, pay, record, declared), a
    ``float64`` amount and an ``int32`` sid. The index is an empty
    UTC-localized ``DatetimeIndex``.
    """
    record_dtype = [
        ('ex_date', 'datetime64[ns]'),
        ('pay_date', 'datetime64[ns]'),
        ('record_date', 'datetime64[ns]'),
        ('declared_date', 'datetime64[ns]'),
        ('amount', 'float64'),
        ('sid', 'int32'),
    ]
    no_rows = np.array([], dtype=record_dtype)
    return pd.DataFrame(no_rows, index=pd.DatetimeIndex([], tz='UTC'))
|
|
|
|
|
|
def create_empty_splits_mergers_frame():
    """
    Return a zero-row splits/mergers frame with typed columns.

    Columns: ``effective_date`` (int64), ``ratio`` (float64) and ``sid``
    (int64). The index is an empty, timezone-naive ``DatetimeIndex``.
    """
    record_dtype = [
        ('effective_date', 'int64'),
        ('ratio', 'float64'),
        ('sid', 'int64'),
    ]
    no_rows = np.array([], dtype=record_dtype)
    return pd.DataFrame(no_rows, index=pd.DatetimeIndex([]))
|
|
|
|
|
|
@expect_dimensions(array=2)
def permute_rows(seed, array):
    """
    Shuffle each row in ``array`` based on permutations generated by ``seed``.

    Parameters
    ----------
    seed : int
        Seed for numpy.RandomState
    array : np.ndarray[ndim=2]
        Array over which to apply permutations.
    """
    rng = np.random.RandomState(seed)
    # One draw from ``rng.permutation`` per row (axis 1 runs along a row).
    return np.apply_along_axis(rng.permutation, axis=1, arr=array)
|
|
|
|
|
|
@nottest
def make_test_handler(testcase, *args, **kwargs):
    """
    Create a TestHandler for ``testcase`` that can be used to inspect log
    messages emitted during the test.

    Parameters
    ----------
    testcase : unittest.TestCase
        The test case that will use the handler.
    *args, **kwargs
        Passed through to the ``TestHandler`` constructor.

    Returns
    -------
    handler : logbook.TestHandler
        The newly created handler. Its ``close`` method is registered as a
        cleanup on ``testcase``.
    """
    log_handler = TestHandler(*args, **kwargs)
    # Let the test case close the handler during its cleanup phase.
    testcase.addCleanup(log_handler.close)
    return log_handler
|
|
|
|
|
|
def write_compressed(path, content):
    """
    Gzip ``content`` and write it to the file at ``path``.

    Parameters
    ----------
    path : str
        Location of the gzip file to create.
    content : bytes
        The uncompressed payload to store.
    """
    with gzip.open(path, mode='wb') as gz_out:
        gz_out.write(content)
|
|
|
|
|
|
def read_compressed(path):
    """
    Read a compressed (gzipped) file from ``path``.

    Parameters
    ----------
    path : str
        Location of the gzip file to read.

    Returns
    -------
    content : bytes
        The decompressed contents of the file.
    """
    with gzip.open(path, mode='rb') as gz_in:
        payload = gz_in.read()
    return payload
|
|
|
|
|
|
# Absolute path of the directory two levels above the (symlink-resolved)
# directory that contains this module.
zipline_git_root = abspath(join(realpath(dirname(__file__)), '..', '..'))
|
|
|
|
|
|
@nottest
def test_resource_path(*path_parts):
    """
    Build a path underneath ``<zipline_git_root>/tests/resources``.

    Parameters
    ----------
    *path_parts : str
        Path components to append after the ``tests/resources`` directory.

    Returns
    -------
    path : str
        The joined path.
    """
    components = ('tests', 'resources') + path_parts
    return os.path.join(zipline_git_root, *components)
|
|
|
|
|
|
@contextmanager
def patch_os_environment(remove=None, **values):
    """
    Context manager for patching the operating system environment.

    Parameters
    ----------
    remove : iterable[str], optional
        Names of environment variables to unset for the duration of the
        block. Names that are not currently set are ignored.
    **values
        Environment variables to set for the duration of the block, given
        as ``NAME=value`` pairs.

    Notes
    -----
    On exit, every variable named in ``remove`` or ``values`` is put back
    to the value it had on entry; variables that did not exist on entry
    are deleted again. This also happens when setting up the patch fails
    part-way through.
    """
    # Maps each touched variable to its value on entry; ``None`` marks a
    # variable that was not set. ``setdefault`` keeps the first (i.e. the
    # true original) value when a name is listed more than once.
    old_values = {}
    try:
        for key in remove or ():
            old_values.setdefault(key, os.environ.pop(key, None))
        for key, value in values.items():
            old_values.setdefault(key, os.getenv(key))
            os.environ[key] = value
        yield
    finally:
        for old_key, old_value in old_values.items():
            if old_value is None:
                # Value was not present when we entered, so remove it if
                # it's still present.
                os.environ.pop(old_key, None)
            else:
                # Restore the old value.
                os.environ[old_key] = old_value
|
|
|
|
|
|
class tmp_dir(TempDirectory, object):
    """New-style class wrapper around ``TempDirectory`` for Python 2.

    ``object`` is added as a base class; no behavior is changed.
    """
|
|
|
|
|
|
class _TmpBarReader(with_metaclass(ABCMeta, tmp_dir)):
    """Shared base for tmp_bcolz_minute_bar_reader and
    tmp_bcolz_daily_bar_reader.

    Entering the context writes ``data`` into the temporary directory via
    ``_write`` and returns ``_reader_cls`` opened on that directory.

    Parameters
    ----------
    env : TradingEnvironment
        The trading env.
    days : pd.DatetimeIndex
        The days to write for.
    data : dict[int -> pd.DataFrame]
        The data to write.
    path : str, optional
        The path to the directory to write the data into. If not given, this
        will be a unique name.
    """
    def __init__(self, env, days, data, path=None):
        super(_TmpBarReader, self).__init__(path=path)
        self._env, self._days, self._data = env, days, data

    @abstractproperty
    def _reader_cls(self):
        # Subclasses provide the reader type constructed in ``__enter__``.
        raise NotImplementedError('_reader')

    @abstractmethod
    def _write(self, env, days, path, data):
        # Subclasses write ``data`` into the directory at ``path``.
        raise NotImplementedError('_write')

    def __enter__(self):
        directory = super(_TmpBarReader, self).__enter__()
        try:
            target = directory.path
            self._write(self._env, self._days, target, self._data)
            return self._reader_cls(target)
        except:  # noqa: E722 - deliberately catches everything, then re-raises
            # The directory has already been entered, so tear it down before
            # propagating whatever went wrong.
            self.__exit__(None, None, None)
            raise
|
|
|
|
|
|
class tmp_bcolz_minute_bar_reader(_TmpBarReader):
    """Context manager producing a temporary BcolzMinuteBarReader.

    Parameters
    ----------
    env : TradingEnvironment
        The trading env.
    days : pd.DatetimeIndex
        The days to write for.
    data : iterable[(int, pd.DataFrame)]
        The data to write.
    path : str, optional
        The path to the directory to write the data into. If not given, this
        will be a unique name.

    See Also
    --------
    tmp_bcolz_daily_bar_reader
    """
    # The data is written with ``write_bcolz_minute_data`` and read back
    # through ``BcolzMinuteBarReader``.
    _write = staticmethod(write_bcolz_minute_data)
    _reader_cls = BcolzMinuteBarReader
|
|
|
|
|
|
class tmp_bcolz_daily_bar_reader(_TmpBarReader):
    """Context manager producing a temporary BcolzDailyBarReader.

    Parameters
    ----------
    env : TradingEnvironment
        The trading env.
    days : pd.DatetimeIndex
        The days to write for.
    data : dict[int -> pd.DataFrame]
        The data to write.
    path : str, optional
        The path to the directory to write the data into. If not given, this
        will be a unique name.

    See Also
    --------
    tmp_bcolz_minute_bar_reader
    """
    _reader_cls = BcolzDailyBarReader

    @staticmethod
    def _write(env, days, path, data):
        # ``env`` is part of the shared ``_write`` signature but is not
        # used by the daily writer.
        daily_writer = BcolzDailyBarWriter(path, days)
        daily_writer.write(data)
|
|
|
|
|
|
@contextmanager
def patch_read_csv(url_map, module=pd, strict=False):
    """Temporarily replace ``read_csv`` so that selected inputs are redirected.

    Parameters
    ----------
    url_map : mapping[str or file-like object -> str or file-like object]
        Maps the first argument passed to ``read_csv`` to the source that
        should be read instead.
    module : module, optional
        The module whose ``read_csv`` attribute is patched. By default this
        is ``pandas``. Pass another module when it binds ``read_csv`` early
        (``from pandas import read_csv``) rather than looking it up late
        (``import pandas as pd; pd.read_csv``).
    strict : bool, optional
        If true, calling ``read_csv`` with an input that is not a key of
        ``url_map`` raises ``AssertionError``. Otherwise such calls are
        forwarded unchanged.
    """
    # Captured before patching so redirected calls reach the real reader.
    real_read_csv = pd.read_csv

    def redirecting_read_csv(filepath_or_buffer, *args, **kwargs):
        if filepath_or_buffer in url_map:
            source = url_map[filepath_or_buffer]
        elif strict:
            raise AssertionError(
                'attempted to call read_csv on %r which not in the url map' %
                filepath_or_buffer,
            )
        else:
            source = filepath_or_buffer
        return real_read_csv(source, *args, **kwargs)

    with patch.object(module, 'read_csv', redirecting_read_csv):
        yield
|