mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 02:54:16 +08:00
71a34bf7ac
The daily/session bar reader's `spot_price` took the same parameters and returned the same kind of output as the minute bar reader's `get_value`. Standardize on one method to make a common interface, which may be formally factored out in a later patch; to help enable writing reader implementations or mixins which can be agnostic to the bar frequency.
1521 lines
44 KiB
Python
1521 lines
44 KiB
Python
from abc import ABCMeta, abstractmethod, abstractproperty
|
|
from contextlib import contextmanager
|
|
from functools import wraps
|
|
import gzip
|
|
from inspect import getargspec
|
|
from itertools import (
|
|
combinations,
|
|
count,
|
|
product,
|
|
)
|
|
import operator
|
|
import os
|
|
from os.path import abspath, dirname, join, realpath
|
|
import shutil
|
|
from sys import _getframe
|
|
import tempfile
|
|
|
|
from logbook import TestHandler
|
|
from mock import patch
|
|
from nose.tools import nottest
|
|
from numpy.testing import assert_allclose, assert_array_equal
|
|
import pandas as pd
|
|
from six import itervalues, iteritems, with_metaclass
|
|
from six.moves import filter, map
|
|
from sqlalchemy import create_engine
|
|
from testfixtures import TempDirectory
|
|
from toolz import concat, curry
|
|
|
|
from zipline.assets import AssetFinder, AssetDBWriter
|
|
from zipline.assets.synthetic import make_simple_equity_info
|
|
from zipline.data.data_portal import DataPortal
|
|
from zipline.data.minute_bars import (
|
|
BcolzMinuteBarReader,
|
|
BcolzMinuteBarWriter,
|
|
US_EQUITIES_MINUTES_PER_DAY
|
|
)
|
|
from zipline.data.us_equity_pricing import (
|
|
BcolzDailyBarReader,
|
|
BcolzDailyBarWriter,
|
|
SQLiteAdjustmentWriter,
|
|
)
|
|
from zipline.finance.trading import TradingEnvironment
|
|
from zipline.finance.order import ORDER_STATUS
|
|
from zipline.lib.labelarray import LabelArray
|
|
from zipline.pipeline.data import USEquityPricing
|
|
from zipline.pipeline.engine import SimplePipelineEngine
|
|
from zipline.pipeline.factors import CustomFactor
|
|
from zipline.pipeline.loaders.testing import make_seeded_random_loader
|
|
from zipline.utils import security_list
|
|
from zipline.utils.calendars import get_calendar
|
|
from zipline.utils.input_validation import expect_dimensions
|
|
from zipline.utils.numpy_utils import as_column
|
|
from zipline.utils.sentinel import sentinel
|
|
|
|
import numpy as np
|
|
from numpy import float64
|
|
|
|
|
|
EPOCH = pd.Timestamp(0, tz='UTC')
|
|
|
|
|
|
def seconds_to_timestamp(seconds):
|
|
return pd.Timestamp(seconds, unit='s', tz='UTC')
|
|
|
|
|
|
def to_utc(time_str):
|
|
"""Convert a string in US/Eastern time to UTC"""
|
|
return pd.Timestamp(time_str, tz='US/Eastern').tz_convert('UTC')
|
|
|
|
|
|
def str_to_seconds(s):
|
|
"""
|
|
Convert a pandas-intelligible string to (integer) seconds since UTC.
|
|
|
|
>>> from pandas import Timestamp
|
|
>>> (Timestamp('2014-01-01') - Timestamp(0)).total_seconds()
|
|
1388534400.0
|
|
>>> str_to_seconds('2014-01-01')
|
|
1388534400
|
|
"""
|
|
return int((pd.Timestamp(s, tz='UTC') - EPOCH).total_seconds())
|
|
|
|
|
|
def drain_zipline(test, zipline):
|
|
output = []
|
|
transaction_count = 0
|
|
msg_counter = 0
|
|
# start the simulation
|
|
for update in zipline:
|
|
msg_counter += 1
|
|
output.append(update)
|
|
if 'daily_perf' in update:
|
|
transaction_count += \
|
|
len(update['daily_perf']['transactions'])
|
|
|
|
return output, transaction_count
|
|
|
|
|
|
def check_algo_results(test,
|
|
results,
|
|
expected_transactions_count=None,
|
|
expected_order_count=None,
|
|
expected_positions_count=None,
|
|
sid=None):
|
|
|
|
if expected_transactions_count is not None:
|
|
txns = flatten_list(results["transactions"])
|
|
test.assertEqual(expected_transactions_count, len(txns))
|
|
|
|
if expected_positions_count is not None:
|
|
raise NotImplementedError
|
|
|
|
if expected_order_count is not None:
|
|
# de-dup orders on id, because orders are put back into perf packets
|
|
# whenever they a txn is filled
|
|
orders = set([order['id'] for order in
|
|
flatten_list(results["orders"])])
|
|
|
|
test.assertEqual(expected_order_count, len(orders))
|
|
|
|
|
|
def flatten_list(list):
|
|
return [item for sublist in list for item in sublist]
|
|
|
|
|
|
def assert_single_position(test, zipline):
|
|
|
|
output, transaction_count = drain_zipline(test, zipline)
|
|
|
|
if 'expected_transactions' in test.zipline_test_config:
|
|
test.assertEqual(
|
|
test.zipline_test_config['expected_transactions'],
|
|
transaction_count
|
|
)
|
|
else:
|
|
test.assertEqual(
|
|
test.zipline_test_config['order_count'],
|
|
transaction_count
|
|
)
|
|
|
|
# the final message is the risk report, the second to
|
|
# last is the final day's results. Positions is a list of
|
|
# dicts.
|
|
closing_positions = output[-2]['daily_perf']['positions']
|
|
|
|
# confirm that all orders were filled.
|
|
# iterate over the output updates, overwriting
|
|
# orders when they are updated. Then check the status on all.
|
|
orders_by_id = {}
|
|
for update in output:
|
|
if 'daily_perf' in update:
|
|
if 'orders' in update['daily_perf']:
|
|
for order in update['daily_perf']['orders']:
|
|
orders_by_id[order['id']] = order
|
|
|
|
for order in itervalues(orders_by_id):
|
|
test.assertEqual(
|
|
order['status'],
|
|
ORDER_STATUS.FILLED,
|
|
"")
|
|
|
|
test.assertEqual(
|
|
len(closing_positions),
|
|
1,
|
|
"Portfolio should have one position."
|
|
)
|
|
|
|
sid = test.zipline_test_config['sid']
|
|
test.assertEqual(
|
|
closing_positions[0]['sid'],
|
|
sid,
|
|
"Portfolio should have one position in " + str(sid)
|
|
)
|
|
|
|
return output, transaction_count
|
|
|
|
|
|
@contextmanager
|
|
def security_list_copy():
|
|
old_dir = security_list.SECURITY_LISTS_DIR
|
|
new_dir = tempfile.mkdtemp()
|
|
try:
|
|
for subdir in os.listdir(old_dir):
|
|
shutil.copytree(os.path.join(old_dir, subdir),
|
|
os.path.join(new_dir, subdir))
|
|
with patch.object(security_list, 'SECURITY_LISTS_DIR', new_dir), \
|
|
patch.object(security_list, 'using_copy', True,
|
|
create=True):
|
|
yield
|
|
finally:
|
|
shutil.rmtree(new_dir, True)
|
|
|
|
|
|
def add_security_data(adds, deletes):
|
|
if not hasattr(security_list, 'using_copy'):
|
|
raise Exception('add_security_data must be used within '
|
|
'security_list_copy context')
|
|
directory = os.path.join(
|
|
security_list.SECURITY_LISTS_DIR,
|
|
"leveraged_etf_list/20150127/20150125"
|
|
)
|
|
if not os.path.exists(directory):
|
|
os.makedirs(directory)
|
|
del_path = os.path.join(directory, "delete")
|
|
with open(del_path, 'w') as f:
|
|
for sym in deletes:
|
|
f.write(sym)
|
|
f.write('\n')
|
|
add_path = os.path.join(directory, "add")
|
|
with open(add_path, 'w') as f:
|
|
for sym in adds:
|
|
f.write(sym)
|
|
f.write('\n')
|
|
|
|
|
|
def all_pairs_matching_predicate(values, pred):
|
|
"""
|
|
Return an iterator of all pairs, (v0, v1) from values such that
|
|
|
|
`pred(v0, v1) == True`
|
|
|
|
Parameters
|
|
----------
|
|
values : iterable
|
|
pred : function
|
|
|
|
Returns
|
|
-------
|
|
pairs_iterator : generator
|
|
Generator yielding pairs matching `pred`.
|
|
|
|
Examples
|
|
--------
|
|
>>> from zipline.testing import all_pairs_matching_predicate
|
|
>>> from operator import eq, lt
|
|
>>> list(all_pairs_matching_predicate(range(5), eq))
|
|
[(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
|
|
>>> list(all_pairs_matching_predicate("abcd", lt))
|
|
[('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'c'), ('b', 'd'), ('c', 'd')]
|
|
"""
|
|
return filter(lambda pair: pred(*pair), product(values, repeat=2))
|
|
|
|
|
|
def product_upper_triangle(values, include_diagonal=False):
|
|
"""
|
|
Return an iterator over pairs, (v0, v1), drawn from values.
|
|
|
|
If `include_diagonal` is True, returns all pairs such that v0 <= v1.
|
|
If `include_diagonal` is False, returns all pairs such that v0 < v1.
|
|
"""
|
|
return all_pairs_matching_predicate(
|
|
values,
|
|
operator.le if include_diagonal else operator.lt,
|
|
)
|
|
|
|
|
|
def all_subindices(index):
|
|
"""
|
|
Return all valid sub-indices of a pandas Index.
|
|
"""
|
|
return (
|
|
index[start:stop]
|
|
for start, stop in product_upper_triangle(range(len(index) + 1))
|
|
)
|
|
|
|
|
|
def chrange(start, stop):
|
|
"""
|
|
Construct an iterable of length-1 strings beginning with `start` and ending
|
|
with `stop`.
|
|
|
|
Parameters
|
|
----------
|
|
start : str
|
|
The first character.
|
|
stop : str
|
|
The last character.
|
|
|
|
Returns
|
|
-------
|
|
chars: iterable[str]
|
|
Iterable of strings beginning with start and ending with stop.
|
|
|
|
Example
|
|
-------
|
|
>>> chrange('A', 'C')
|
|
['A', 'B', 'C']
|
|
"""
|
|
return list(map(chr, range(ord(start), ord(stop) + 1)))
|
|
|
|
|
|
def make_trade_data_for_asset_info(dates,
|
|
asset_info,
|
|
price_start,
|
|
price_step_by_date,
|
|
price_step_by_sid,
|
|
volume_start,
|
|
volume_step_by_date,
|
|
volume_step_by_sid,
|
|
frequency,
|
|
writer=None):
|
|
"""
|
|
Convert the asset info dataframe into a dataframe of trade data for each
|
|
sid, and write to the writer if provided. Write NaNs for locations where
|
|
assets did not exist. Return a dict of the dataframes, keyed by sid.
|
|
"""
|
|
trade_data = {}
|
|
sids = asset_info.index
|
|
|
|
price_sid_deltas = np.arange(len(sids), dtype=float64) * price_step_by_sid
|
|
price_date_deltas = (np.arange(len(dates), dtype=float64) *
|
|
price_step_by_date)
|
|
prices = (price_sid_deltas + as_column(price_date_deltas)) + price_start
|
|
|
|
volume_sid_deltas = np.arange(len(sids)) * volume_step_by_sid
|
|
volume_date_deltas = np.arange(len(dates)) * volume_step_by_date
|
|
volumes = volume_sid_deltas + as_column(volume_date_deltas) + volume_start
|
|
|
|
for j, sid in enumerate(sids):
|
|
start_date, end_date = asset_info.loc[sid, ['start_date', 'end_date']]
|
|
# Normalize here so the we still generate non-NaN values on the minutes
|
|
# for an asset's last trading day.
|
|
for i, date in enumerate(dates.normalize()):
|
|
if not (start_date <= date <= end_date):
|
|
prices[i, j] = 0
|
|
volumes[i, j] = 0
|
|
|
|
df = pd.DataFrame(
|
|
{
|
|
"open": prices[:, j],
|
|
"high": prices[:, j],
|
|
"low": prices[:, j],
|
|
"close": prices[:, j],
|
|
"volume": volumes[:, j],
|
|
},
|
|
index=dates,
|
|
)
|
|
|
|
if writer:
|
|
writer.write_sid(sid, df)
|
|
|
|
trade_data[sid] = df
|
|
|
|
return trade_data
|
|
|
|
|
|
def check_allclose(actual,
|
|
desired,
|
|
rtol=1e-07,
|
|
atol=0,
|
|
err_msg='',
|
|
verbose=True):
|
|
"""
|
|
Wrapper around np.testing.assert_allclose that also verifies that inputs
|
|
are ndarrays.
|
|
|
|
See Also
|
|
--------
|
|
np.assert_allclose
|
|
"""
|
|
if type(actual) != type(desired):
|
|
raise AssertionError("%s != %s" % (type(actual), type(desired)))
|
|
return assert_allclose(
|
|
actual,
|
|
desired,
|
|
atol=atol,
|
|
rtol=rtol,
|
|
err_msg=err_msg,
|
|
verbose=verbose,
|
|
)
|
|
|
|
|
|
def check_arrays(x, y, err_msg='', verbose=True, check_dtypes=True):
|
|
"""
|
|
Wrapper around np.testing.assert_array_equal that also verifies that inputs
|
|
are ndarrays.
|
|
|
|
See Also
|
|
--------
|
|
np.assert_array_equal
|
|
"""
|
|
assert type(x) == type(y), "{x} != {y}".format(x=type(x), y=type(y))
|
|
assert x.dtype == y.dtype, "{x.dtype} != {y.dtype}".format(x=x, y=y)
|
|
|
|
if isinstance(x, LabelArray):
|
|
# Check that both arrays have missing values in the same locations...
|
|
assert_array_equal(
|
|
x.is_missing(),
|
|
y.is_missing(),
|
|
err_msg=err_msg,
|
|
verbose=verbose,
|
|
)
|
|
# ...then check the actual values as well.
|
|
x = x.as_string_array()
|
|
y = y.as_string_array()
|
|
|
|
return assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)
|
|
|
|
|
|
class UnexpectedAttributeAccess(Exception):
|
|
pass
|
|
|
|
|
|
class ExplodingObject(object):
|
|
"""
|
|
Object that will raise an exception on any attribute access.
|
|
|
|
Useful for verifying that an object is never touched during a
|
|
function/method call.
|
|
"""
|
|
def __getattribute__(self, name):
|
|
raise UnexpectedAttributeAccess(name)
|
|
|
|
|
|
def write_minute_data(trading_calendar, tempdir, minutes, sids):
|
|
first_session = trading_calendar.minute_to_session_label(
|
|
minutes[0], direction="none"
|
|
)
|
|
last_session = trading_calendar.minute_to_session_label(
|
|
minutes[-1], direction="none"
|
|
)
|
|
|
|
sessions = trading_calendar.sessions_in_range(first_session, last_session)
|
|
|
|
write_bcolz_minute_data(
|
|
trading_calendar,
|
|
sessions,
|
|
tempdir.path,
|
|
create_minute_bar_data(minutes, sids),
|
|
)
|
|
return tempdir.path
|
|
|
|
|
|
def create_minute_bar_data(minutes, sids):
|
|
length = len(minutes)
|
|
for sid_idx, sid in enumerate(sids):
|
|
yield sid, pd.DataFrame(
|
|
{
|
|
'open': np.arange(length) + 10 + sid_idx,
|
|
'high': np.arange(length) + 15 + sid_idx,
|
|
'low': np.arange(length) + 8 + sid_idx,
|
|
'close': np.arange(length) + 10 + sid_idx,
|
|
'volume': 100 + sid_idx,
|
|
},
|
|
index=minutes,
|
|
)
|
|
|
|
|
|
def create_daily_bar_data(sessions, sids):
|
|
length = len(sessions)
|
|
for sid_idx, sid in enumerate(sids):
|
|
yield sid, pd.DataFrame(
|
|
{
|
|
"open": (np.array(range(10, 10 + length)) + sid_idx),
|
|
"high": (np.array(range(15, 15 + length)) + sid_idx),
|
|
"low": (np.array(range(8, 8 + length)) + sid_idx),
|
|
"close": (np.array(range(10, 10 + length)) + sid_idx),
|
|
"volume": np.array(range(100, 100 + length)) + sid_idx,
|
|
"day": [session.value for session in sessions]
|
|
},
|
|
index=sessions,
|
|
)
|
|
|
|
|
|
def write_daily_data(tempdir, sim_params, sids, trading_calendar):
|
|
path = os.path.join(tempdir.path, "testdaily.bcolz")
|
|
BcolzDailyBarWriter(path, trading_calendar,
|
|
sim_params.start_session,
|
|
sim_params.end_session).write(
|
|
create_daily_bar_data(sim_params.sessions, sids),
|
|
)
|
|
|
|
return path
|
|
|
|
|
|
def create_data_portal(asset_finder, tempdir, sim_params, sids,
|
|
trading_calendar, adjustment_reader=None):
|
|
if sim_params.data_frequency == "daily":
|
|
daily_path = write_daily_data(tempdir, sim_params, sids,
|
|
trading_calendar)
|
|
|
|
equity_daily_reader = BcolzDailyBarReader(daily_path)
|
|
|
|
return DataPortal(
|
|
asset_finder, trading_calendar,
|
|
first_trading_day=equity_daily_reader.first_trading_day,
|
|
equity_daily_reader=equity_daily_reader,
|
|
adjustment_reader=adjustment_reader
|
|
)
|
|
else:
|
|
minutes = trading_calendar.minutes_in_range(
|
|
sim_params.first_open,
|
|
sim_params.last_close
|
|
)
|
|
|
|
minute_path = write_minute_data(trading_calendar, tempdir, minutes,
|
|
sids)
|
|
|
|
equity_minute_reader = BcolzMinuteBarReader(minute_path)
|
|
|
|
return DataPortal(
|
|
asset_finder, trading_calendar,
|
|
first_trading_day=equity_minute_reader.first_trading_day,
|
|
equity_minute_reader=equity_minute_reader,
|
|
adjustment_reader=adjustment_reader
|
|
)
|
|
|
|
|
|
def write_bcolz_minute_data(trading_calendar, days, path, data):
|
|
BcolzMinuteBarWriter(
|
|
path,
|
|
trading_calendar,
|
|
days[0],
|
|
days[-1],
|
|
US_EQUITIES_MINUTES_PER_DAY
|
|
).write(data)
|
|
|
|
|
|
def create_minute_df_for_asset(trading_calendar,
|
|
start_dt,
|
|
end_dt,
|
|
interval=1,
|
|
start_val=1,
|
|
minute_blacklist=None):
|
|
|
|
asset_minutes = trading_calendar.minutes_for_sessions_in_range(
|
|
start_dt, end_dt
|
|
)
|
|
minutes_count = len(asset_minutes)
|
|
minutes_arr = np.array(range(start_val, start_val + minutes_count))
|
|
|
|
df = pd.DataFrame(
|
|
{
|
|
"open": minutes_arr + 1,
|
|
"high": minutes_arr + 2,
|
|
"low": minutes_arr - 1,
|
|
"close": minutes_arr,
|
|
"volume": 100 * minutes_arr,
|
|
},
|
|
index=asset_minutes,
|
|
)
|
|
|
|
if interval > 1:
|
|
counter = 0
|
|
while counter < len(minutes_arr):
|
|
df[counter:(counter + interval - 1)] = 0
|
|
counter += interval
|
|
|
|
if minute_blacklist is not None:
|
|
for minute in minute_blacklist:
|
|
df.loc[minute] = 0
|
|
|
|
return df
|
|
|
|
|
|
def create_daily_df_for_asset(trading_calendar, start_day, end_day,
|
|
interval=1):
|
|
days = trading_calendar.sessions_in_range(start_day, end_day)
|
|
days_count = len(days)
|
|
days_arr = np.arange(days_count) + 2
|
|
|
|
df = pd.DataFrame(
|
|
{
|
|
"open": days_arr + 1,
|
|
"high": days_arr + 2,
|
|
"low": days_arr - 1,
|
|
"close": days_arr,
|
|
"volume": days_arr * 100,
|
|
},
|
|
index=days,
|
|
)
|
|
|
|
if interval > 1:
|
|
# only keep every 'interval' rows
|
|
for idx, _ in enumerate(days_arr):
|
|
if (idx + 1) % interval != 0:
|
|
df["open"].iloc[idx] = 0
|
|
df["high"].iloc[idx] = 0
|
|
df["low"].iloc[idx] = 0
|
|
df["close"].iloc[idx] = 0
|
|
df["volume"].iloc[idx] = 0
|
|
|
|
return df
|
|
|
|
|
|
def trades_by_sid_to_dfs(trades_by_sid, index):
|
|
for sidint, trades in iteritems(trades_by_sid):
|
|
opens = []
|
|
highs = []
|
|
lows = []
|
|
closes = []
|
|
volumes = []
|
|
for trade in trades:
|
|
opens.append(trade["open_price"])
|
|
highs.append(trade["high"])
|
|
lows.append(trade["low"])
|
|
closes.append(trade["close_price"])
|
|
volumes.append(trade["volume"])
|
|
|
|
yield sidint, pd.DataFrame(
|
|
{
|
|
"open": opens,
|
|
"high": highs,
|
|
"low": lows,
|
|
"close": closes,
|
|
"volume": volumes,
|
|
},
|
|
index=index,
|
|
)
|
|
|
|
|
|
def create_data_portal_from_trade_history(asset_finder, trading_calendar,
|
|
tempdir, sim_params, trades_by_sid):
|
|
if sim_params.data_frequency == "daily":
|
|
path = os.path.join(tempdir.path, "testdaily.bcolz")
|
|
writer = BcolzDailyBarWriter(
|
|
path, trading_calendar,
|
|
sim_params.start_session,
|
|
sim_params.end_session
|
|
)
|
|
writer.write(
|
|
trades_by_sid_to_dfs(trades_by_sid, sim_params.sessions),
|
|
)
|
|
|
|
equity_daily_reader = BcolzDailyBarReader(path)
|
|
|
|
return DataPortal(
|
|
asset_finder, trading_calendar,
|
|
first_trading_day=equity_daily_reader.first_trading_day,
|
|
equity_daily_reader=equity_daily_reader,
|
|
)
|
|
else:
|
|
minutes = trading_calendar.minutes_in_range(
|
|
sim_params.first_open,
|
|
sim_params.last_close
|
|
)
|
|
|
|
length = len(minutes)
|
|
assets = {}
|
|
|
|
for sidint, trades in iteritems(trades_by_sid):
|
|
opens = np.zeros(length)
|
|
highs = np.zeros(length)
|
|
lows = np.zeros(length)
|
|
closes = np.zeros(length)
|
|
volumes = np.zeros(length)
|
|
|
|
for trade in trades:
|
|
# put them in the right place
|
|
idx = minutes.searchsorted(trade.dt)
|
|
|
|
opens[idx] = trade.open_price * 1000
|
|
highs[idx] = trade.high * 1000
|
|
lows[idx] = trade.low * 1000
|
|
closes[idx] = trade.close_price * 1000
|
|
volumes[idx] = trade.volume
|
|
|
|
assets[sidint] = pd.DataFrame({
|
|
"open": opens,
|
|
"high": highs,
|
|
"low": lows,
|
|
"close": closes,
|
|
"volume": volumes,
|
|
"dt": minutes
|
|
}).set_index("dt")
|
|
|
|
write_bcolz_minute_data(
|
|
trading_calendar,
|
|
sim_params.sessions,
|
|
tempdir.path,
|
|
assets
|
|
)
|
|
|
|
equity_minute_reader = BcolzMinuteBarReader(tempdir.path)
|
|
|
|
return DataPortal(
|
|
asset_finder, trading_calendar,
|
|
first_trading_day=equity_minute_reader.first_trading_day,
|
|
equity_minute_reader=equity_minute_reader,
|
|
)
|
|
|
|
|
|
class FakeDataPortal(DataPortal):
|
|
def __init__(self, env=None, trading_calendar=None,
|
|
first_trading_day=None):
|
|
if env is None:
|
|
env = TradingEnvironment()
|
|
|
|
if trading_calendar is None:
|
|
trading_calendar = get_calendar("NYSE")
|
|
|
|
super(FakeDataPortal, self).__init__(env.asset_finder,
|
|
trading_calendar,
|
|
first_trading_day)
|
|
|
|
def get_spot_value(self, asset, field, dt, data_frequency):
|
|
if field == "volume":
|
|
return 100
|
|
else:
|
|
return 1.0
|
|
|
|
def get_history_window(self, assets, end_dt, bar_count, frequency, field,
|
|
ffill=True):
|
|
if frequency == "1d":
|
|
end_idx = \
|
|
self.trading_calendar.all_sessions.searchsorted(end_dt)
|
|
days = self.trading_calendar.all_sessions[
|
|
(end_idx - bar_count + 1):(end_idx + 1)
|
|
]
|
|
|
|
df = pd.DataFrame(
|
|
np.full((bar_count, len(assets)), 100),
|
|
index=days,
|
|
columns=assets
|
|
)
|
|
|
|
return df
|
|
|
|
|
|
class FetcherDataPortal(DataPortal):
|
|
"""
|
|
Mock dataportal that returns fake data for history and non-fetcher
|
|
spot value.
|
|
"""
|
|
def __init__(self, asset_finder, trading_calendar, first_trading_day=None):
|
|
super(FetcherDataPortal, self).__init__(asset_finder, trading_calendar,
|
|
first_trading_day)
|
|
|
|
def get_spot_value(self, asset, field, dt, data_frequency):
|
|
# if this is a fetcher field, exercise the regular code path
|
|
if self._is_extra_source(asset, field, self._augmented_sources_map):
|
|
return super(FetcherDataPortal, self).get_spot_value(
|
|
asset, field, dt, data_frequency)
|
|
|
|
# otherwise just return a fixed value
|
|
return int(asset)
|
|
|
|
def _get_daily_window_for_sid(self, asset, field, days_in_window,
|
|
extra_slot=True):
|
|
return np.arange(days_in_window, dtype=np.float64)
|
|
|
|
def _get_minute_window_for_asset(self, asset, field, minutes_for_window):
|
|
return np.arange(minutes_for_window, dtype=np.float64)
|
|
|
|
|
|
class tmp_assets_db(object):
|
|
"""Create a temporary assets sqlite database.
|
|
This is meant to be used as a context manager.
|
|
|
|
Parameters
|
|
----------
|
|
url : string
|
|
The URL for the database connection.
|
|
**frames
|
|
The frames to pass to the AssetDBWriter.
|
|
By default this maps equities:
|
|
('A', 'B', 'C') -> map(ord, 'ABC')
|
|
|
|
See Also
|
|
--------
|
|
empty_assets_db
|
|
tmp_asset_finder
|
|
"""
|
|
_default_equities = sentinel('_default_equities')
|
|
|
|
def __init__(self,
|
|
url='sqlite:///:memory:',
|
|
equities=_default_equities,
|
|
**frames):
|
|
self._url = url
|
|
self._eng = None
|
|
if equities is self._default_equities:
|
|
equities = make_simple_equity_info(
|
|
list(map(ord, 'ABC')),
|
|
pd.Timestamp(0),
|
|
pd.Timestamp('2015'),
|
|
)
|
|
|
|
frames['equities'] = equities
|
|
self._frames = frames
|
|
self._eng = None # set in enter and exit
|
|
|
|
def __enter__(self):
|
|
self._eng = eng = create_engine(self._url)
|
|
AssetDBWriter(eng).write(**self._frames)
|
|
return eng
|
|
|
|
def __exit__(self, *excinfo):
|
|
assert self._eng is not None, '_eng was not set in __enter__'
|
|
self._eng.dispose()
|
|
self._eng = None
|
|
|
|
|
|
def empty_assets_db():
|
|
"""Context manager for creating an empty assets db.
|
|
|
|
See Also
|
|
--------
|
|
tmp_assets_db
|
|
"""
|
|
return tmp_assets_db(equities=None)
|
|
|
|
|
|
class tmp_asset_finder(tmp_assets_db):
|
|
"""Create a temporary asset finder using an in memory sqlite db.
|
|
|
|
Parameters
|
|
----------
|
|
url : string
|
|
The URL for the database connection.
|
|
finder_cls : type, optional
|
|
The type of asset finder to create from the assets db.
|
|
**frames
|
|
Forwarded to ``tmp_assets_db``.
|
|
|
|
See Also
|
|
--------
|
|
tmp_assets_db
|
|
"""
|
|
def __init__(self,
|
|
url='sqlite:///:memory:',
|
|
finder_cls=AssetFinder,
|
|
**frames):
|
|
self._finder_cls = finder_cls
|
|
super(tmp_asset_finder, self).__init__(url=url, **frames)
|
|
|
|
def __enter__(self):
|
|
return self._finder_cls(super(tmp_asset_finder, self).__enter__())
|
|
|
|
|
|
def empty_asset_finder():
|
|
"""Context manager for creating an empty asset finder.
|
|
|
|
See Also
|
|
--------
|
|
empty_assets_db
|
|
tmp_assets_db
|
|
tmp_asset_finder
|
|
"""
|
|
return tmp_asset_finder(equities=None)
|
|
|
|
|
|
class tmp_trading_env(tmp_asset_finder):
|
|
"""Create a temporary trading environment.
|
|
|
|
Parameters
|
|
----------
|
|
finder_cls : type, optional
|
|
The type of asset finder to create from the assets db.
|
|
**frames
|
|
Forwarded to ``tmp_assets_db``.
|
|
|
|
See Also
|
|
--------
|
|
empty_trading_env
|
|
tmp_asset_finder
|
|
"""
|
|
def __enter__(self):
|
|
return TradingEnvironment(
|
|
asset_db_path=super(tmp_trading_env, self).__enter__().engine,
|
|
)
|
|
|
|
|
|
def empty_trading_env():
|
|
return tmp_trading_env(equities=None)
|
|
|
|
|
|
class SubTestFailures(AssertionError):
|
|
def __init__(self, *failures):
|
|
self.failures = failures
|
|
|
|
def __str__(self):
|
|
return 'failures:\n %s' % '\n '.join(
|
|
'\n '.join((
|
|
', '.join('%s=%r' % item for item in scope.items()),
|
|
'%s: %s' % (type(exc).__name__, exc),
|
|
)) for scope, exc in self.failures,
|
|
)
|
|
|
|
|
|
@nottest
|
|
def subtest(iterator, *_names):
|
|
"""
|
|
Construct a subtest in a unittest.
|
|
|
|
Consider using ``zipline.testing.parameter_space`` when subtests
|
|
are constructed over a single input or over the cross-product of multiple
|
|
inputs.
|
|
|
|
``subtest`` works by decorating a function as a subtest. The decorated
|
|
function will be run by iterating over the ``iterator`` and *unpacking the
|
|
values into the function. If any of the runs fail, the result will be put
|
|
into a set and the rest of the tests will be run. Finally, if any failed,
|
|
all of the results will be dumped as one failure.
|
|
|
|
Parameters
|
|
----------
|
|
iterator : iterable[iterable]
|
|
The iterator of arguments to pass to the function.
|
|
*name : iterator[str]
|
|
The names to use for each element of ``iterator``. These will be used
|
|
to print the scope when a test fails. If not provided, it will use the
|
|
integer index of the value as the name.
|
|
|
|
Examples
|
|
--------
|
|
|
|
::
|
|
|
|
class MyTest(TestCase):
|
|
def test_thing(self):
|
|
# Example usage inside another test.
|
|
@subtest(([n] for n in range(100000)), 'n')
|
|
def subtest(n):
|
|
self.assertEqual(n % 2, 0, 'n was not even')
|
|
subtest()
|
|
|
|
@subtest(([n] for n in range(100000)), 'n')
|
|
def test_decorated_function(self, n):
|
|
# Example usage to parameterize an entire function.
|
|
self.assertEqual(n % 2, 1, 'n was not odd')
|
|
|
|
Notes
|
|
-----
|
|
We use this when we:
|
|
|
|
* Will never want to run each parameter individually.
|
|
* Have a large parameter space we are testing
|
|
(see tests/utils/test_events.py).
|
|
|
|
``nose_parameterized.expand`` will create a test for each parameter
|
|
combination which bloats the test output and makes the travis pages slow.
|
|
|
|
We cannot use ``unittest2.TestCase.subTest`` because nose, pytest, and
|
|
nose2 do not support ``addSubTest``.
|
|
|
|
See Also
|
|
--------
|
|
zipline.testing.parameter_space
|
|
"""
|
|
def dec(f):
|
|
@wraps(f)
|
|
def wrapped(*args, **kwargs):
|
|
names = _names
|
|
failures = []
|
|
for scope in iterator:
|
|
scope = tuple(scope)
|
|
try:
|
|
f(*args + scope, **kwargs)
|
|
except Exception as e:
|
|
if not names:
|
|
names = count()
|
|
failures.append((dict(zip(names, scope)), e))
|
|
if failures:
|
|
raise SubTestFailures(*failures)
|
|
|
|
return wrapped
|
|
return dec
|
|
|
|
|
|
class MockDailyBarReader(object):
|
|
def get_value(self, col, sid, dt):
|
|
return 100
|
|
|
|
|
|
def create_mock_adjustment_data(splits=None, dividends=None, mergers=None):
|
|
if splits is None:
|
|
splits = create_empty_splits_mergers_frame()
|
|
elif not isinstance(splits, pd.DataFrame):
|
|
splits = pd.DataFrame(splits)
|
|
|
|
if mergers is None:
|
|
mergers = create_empty_splits_mergers_frame()
|
|
elif not isinstance(mergers, pd.DataFrame):
|
|
mergers = pd.DataFrame(mergers)
|
|
|
|
if dividends is None:
|
|
dividends = create_empty_dividends_frame()
|
|
elif not isinstance(dividends, pd.DataFrame):
|
|
dividends = pd.DataFrame(dividends)
|
|
|
|
return splits, mergers, dividends
|
|
|
|
|
|
def create_mock_adjustments(tempdir, days, splits=None, dividends=None,
|
|
mergers=None):
|
|
path = tempdir.getpath("test_adjustments.db")
|
|
SQLiteAdjustmentWriter(path, MockDailyBarReader(), days).write(
|
|
*create_mock_adjustment_data(splits, dividends, mergers)
|
|
)
|
|
return path
|
|
|
|
|
|
def assert_timestamp_equal(left, right, compare_nat_equal=True, msg=""):
|
|
"""
|
|
Assert that two pandas Timestamp objects are the same.
|
|
|
|
Parameters
|
|
----------
|
|
left, right : pd.Timestamp
|
|
The values to compare.
|
|
compare_nat_equal : bool, optional
|
|
Whether to consider `NaT` values equal. Defaults to True.
|
|
msg : str, optional
|
|
A message to forward to `pd.util.testing.assert_equal`.
|
|
"""
|
|
if compare_nat_equal and left is pd.NaT and right is pd.NaT:
|
|
return
|
|
return pd.util.testing.assert_equal(left, right, msg=msg)
|
|
|
|
|
|
def powerset(values):
|
|
"""
|
|
Return the power set (i.e., the set of all subsets) of entries in `values`.
|
|
"""
|
|
return concat(combinations(values, i) for i in range(len(values) + 1))
|
|
|
|
|
|
def to_series(knowledge_dates, earning_dates):
|
|
"""
|
|
Helper for converting a dict of strings to a Series of datetimes.
|
|
|
|
This is just for making the test cases more readable.
|
|
"""
|
|
return pd.Series(
|
|
index=pd.to_datetime(knowledge_dates),
|
|
data=pd.to_datetime(earning_dates),
|
|
)
|
|
|
|
|
|
def gen_calendars(start, stop, critical_dates):
|
|
"""
|
|
Generate calendars to use as inputs.
|
|
"""
|
|
all_dates = pd.date_range(start, stop, tz='utc')
|
|
for to_drop in map(list, powerset(critical_dates)):
|
|
# Have to yield tuples.
|
|
yield (all_dates.drop(to_drop),)
|
|
|
|
# Also test with the trading calendar.
|
|
trading_days = get_calendar("NYSE").all_days
|
|
yield (trading_days[trading_days.slice_indexer(start, stop)],)
|
|
|
|
|
|
@contextmanager
|
|
def temp_pipeline_engine(calendar, sids, random_seed, symbols=None):
|
|
"""
|
|
A contextManager that yields a SimplePipelineEngine holding a reference to
|
|
an AssetFinder generated via tmp_asset_finder.
|
|
|
|
Parameters
|
|
----------
|
|
calendar : pd.DatetimeIndex
|
|
Calendar to pass to the constructed PipelineEngine.
|
|
sids : iterable[int]
|
|
Sids to use for the temp asset finder.
|
|
random_seed : int
|
|
Integer used to seed instances of SeededRandomLoader.
|
|
symbols : iterable[str], optional
|
|
Symbols for constructed assets. Forwarded to make_simple_equity_info.
|
|
"""
|
|
equity_info = make_simple_equity_info(
|
|
sids=sids,
|
|
start_date=calendar[0],
|
|
end_date=calendar[-1],
|
|
symbols=symbols,
|
|
)
|
|
|
|
loader = make_seeded_random_loader(random_seed, calendar, sids)
|
|
get_loader = lambda column: loader
|
|
|
|
with tmp_asset_finder(equities=equity_info) as finder:
|
|
yield SimplePipelineEngine(get_loader, calendar, finder)
|
|
|
|
|
|
def parameter_space(__fail_fast=False, **params):
|
|
"""
|
|
Wrapper around subtest that allows passing keywords mapping names to
|
|
iterables of values.
|
|
|
|
The decorated test function will be called with the cross-product of all
|
|
possible inputs
|
|
|
|
Usage
|
|
-----
|
|
>>> from unittest import TestCase
|
|
>>> class SomeTestCase(TestCase):
|
|
... @parameter_space(x=[1, 2], y=[2, 3])
|
|
... def test_some_func(self, x, y):
|
|
... # Will be called with every possible combination of x and y.
|
|
... self.assertEqual(somefunc(x, y), expected_result(x, y))
|
|
|
|
See Also
|
|
--------
|
|
zipline.testing.subtest
|
|
"""
|
|
def decorator(f):
|
|
|
|
argspec = getargspec(f)
|
|
if argspec.varargs:
|
|
raise AssertionError("parameter_space() doesn't support *args")
|
|
if argspec.keywords:
|
|
raise AssertionError("parameter_space() doesn't support **kwargs")
|
|
if argspec.defaults:
|
|
raise AssertionError("parameter_space() doesn't support defaults.")
|
|
|
|
# Skip over implicit self.
|
|
argnames = argspec.args
|
|
if argnames[0] == 'self':
|
|
argnames = argnames[1:]
|
|
|
|
extra = set(params) - set(argnames)
|
|
if extra:
|
|
raise AssertionError(
|
|
"Keywords %s supplied to parameter_space() are "
|
|
"not in function signature." % extra
|
|
)
|
|
|
|
unspecified = set(argnames) - set(params)
|
|
if unspecified:
|
|
raise AssertionError(
|
|
"Function arguments %s were not "
|
|
"supplied to parameter_space()." % extra
|
|
)
|
|
|
|
param_sets = product(*(params[name] for name in argnames))
|
|
|
|
if __fail_fast:
|
|
@wraps(f)
|
|
def wrapped(self):
|
|
for args in param_sets:
|
|
f(self, *args)
|
|
return wrapped
|
|
else:
|
|
return subtest(param_sets, *argnames)(f)
|
|
|
|
return decorator
|
|
|
|
|
|
def create_empty_dividends_frame():
|
|
return pd.DataFrame(
|
|
np.array(
|
|
[],
|
|
dtype=[
|
|
('ex_date', 'datetime64[ns]'),
|
|
('pay_date', 'datetime64[ns]'),
|
|
('record_date', 'datetime64[ns]'),
|
|
('declared_date', 'datetime64[ns]'),
|
|
('amount', 'float64'),
|
|
('sid', 'int32'),
|
|
],
|
|
),
|
|
index=pd.DatetimeIndex([], tz='UTC'),
|
|
)
|
|
|
|
|
|
def create_empty_splits_mergers_frame():
|
|
return pd.DataFrame(
|
|
np.array(
|
|
[],
|
|
dtype=[
|
|
('effective_date', 'int64'),
|
|
('ratio', 'float64'),
|
|
('sid', 'int64'),
|
|
],
|
|
),
|
|
index=pd.DatetimeIndex([]),
|
|
)
|
|
|
|
|
|
def make_alternating_boolean_array(shape, first_value=True):
|
|
"""
|
|
Create a 2D numpy array with the given shape containing alternating values
|
|
of False, True, False, True,... along each row and each column.
|
|
|
|
Examples
|
|
--------
|
|
>>> make_alternating_boolean_array((4,4))
|
|
array([[ True, False, True, False],
|
|
[False, True, False, True],
|
|
[ True, False, True, False],
|
|
[False, True, False, True]], dtype=bool)
|
|
>>> make_alternating_boolean_array((4,3), first_value=False)
|
|
array([[False, True, False],
|
|
[ True, False, True],
|
|
[False, True, False],
|
|
[ True, False, True]], dtype=bool)
|
|
"""
|
|
if len(shape) != 2:
|
|
raise ValueError(
|
|
'Shape must be 2-dimensional. Given shape was {}'.format(shape)
|
|
)
|
|
alternating = np.empty(shape, dtype=np.bool)
|
|
for row in alternating:
|
|
row[::2] = first_value
|
|
row[1::2] = not(first_value)
|
|
first_value = not(first_value)
|
|
return alternating
|
|
|
|
|
|
def make_cascading_boolean_array(shape, first_value=True):
|
|
"""
|
|
Create a numpy array with the given shape containing cascading boolean
|
|
values, with `first_value` being the top-left value.
|
|
|
|
Examples
|
|
--------
|
|
>>> make_cascading_boolean_array((4,4))
|
|
array([[ True, True, True, False],
|
|
[ True, True, False, False],
|
|
[ True, False, False, False],
|
|
[False, False, False, False]], dtype=bool)
|
|
>>> make_cascading_boolean_array((4,2))
|
|
array([[ True, False],
|
|
[False, False],
|
|
[False, False],
|
|
[False, False]], dtype=bool)
|
|
>>> make_cascading_boolean_array((2,4))
|
|
array([[ True, True, True, False],
|
|
[ True, True, False, False]], dtype=bool)
|
|
"""
|
|
if len(shape) != 2:
|
|
raise ValueError(
|
|
'Shape must be 2-dimensional. Given shape was {}'.format(shape)
|
|
)
|
|
cascading = np.full(shape, not(first_value), dtype=np.bool)
|
|
ending_col = shape[1] - 1
|
|
for row in cascading:
|
|
if ending_col > 0:
|
|
row[:ending_col] = first_value
|
|
ending_col -= 1
|
|
else:
|
|
break
|
|
return cascading
|
|
|
|
|
|
@expect_dimensions(array=2)
|
|
def permute_rows(seed, array):
|
|
"""
|
|
Shuffle each row in ``array`` based on permutations generated by ``seed``.
|
|
|
|
Parameters
|
|
----------
|
|
seed : int
|
|
Seed for numpy.RandomState
|
|
array : np.ndarray[ndim=2]
|
|
Array over which to apply permutations.
|
|
"""
|
|
rand = np.random.RandomState(seed)
|
|
return np.apply_along_axis(rand.permutation, 1, array)
|
|
|
|
|
|
@nottest
|
|
def make_test_handler(testcase, *args, **kwargs):
|
|
"""
|
|
Returns a TestHandler which will be used by the given testcase. This
|
|
handler can be used to test log messages.
|
|
|
|
Parameters
|
|
----------
|
|
testcase: unittest.TestCase
|
|
The test class in which the log handler will be used.
|
|
*args, **kwargs
|
|
Forwarded to the new TestHandler object.
|
|
|
|
Returns
|
|
-------
|
|
handler: logbook.TestHandler
|
|
The handler to use for the test case.
|
|
"""
|
|
handler = TestHandler(*args, **kwargs)
|
|
testcase.addCleanup(handler.close)
|
|
return handler
|
|
|
|
|
|
def write_compressed(path, content):
|
|
"""
|
|
Write a compressed (gzipped) file to `path`.
|
|
"""
|
|
with gzip.open(path, 'wb') as f:
|
|
f.write(content)
|
|
|
|
|
|
def read_compressed(path):
|
|
"""
|
|
Write a compressed (gzipped) file from `path`.
|
|
"""
|
|
with gzip.open(path, 'rb') as f:
|
|
return f.read()
|
|
|
|
|
|
zipline_git_root = abspath(
|
|
join(realpath(dirname(__file__)), '..', '..'),
|
|
)
|
|
|
|
|
|
@nottest
|
|
def test_resource_path(*path_parts):
|
|
return os.path.join(zipline_git_root, 'tests', 'resources', *path_parts)
|
|
|
|
|
|
@contextmanager
|
|
def patch_os_environment(remove=None, **values):
|
|
"""
|
|
Context manager for patching the operating system environment.
|
|
"""
|
|
old_values = {}
|
|
remove = remove or []
|
|
for key in remove:
|
|
old_values[key] = os.environ.pop(key)
|
|
for key, value in values.iteritems():
|
|
old_values[key] = os.getenv(key)
|
|
os.environ[key] = value
|
|
try:
|
|
yield
|
|
finally:
|
|
for old_key, old_value in old_values.iteritems():
|
|
if old_value is None:
|
|
# Value was not present when we entered, so del it out if it's
|
|
# still present.
|
|
try:
|
|
del os.environ[key]
|
|
except KeyError:
|
|
pass
|
|
else:
|
|
# Restore the old value.
|
|
os.environ[old_key] = old_value
|
|
|
|
|
|
class tmp_dir(TempDirectory, object):
|
|
"""New style class that wrapper for TempDirectory in python 2.
|
|
"""
|
|
pass
|
|
|
|
|
|
class _TmpBarReader(with_metaclass(ABCMeta, tmp_dir)):
|
|
"""A helper for tmp_bcolz_equity_minute_bar_reader and
|
|
tmp_bcolz_equity_daily_bar_reader.
|
|
|
|
Parameters
|
|
----------
|
|
env : TradingEnvironment
|
|
The trading env.
|
|
days : pd.DatetimeIndex
|
|
The days to write for.
|
|
data : dict[int -> pd.DataFrame]
|
|
The data to write.
|
|
path : str, optional
|
|
The path to the directory to write the data into. If not given, this
|
|
will be a unique name.
|
|
"""
|
|
@abstractproperty
|
|
def _reader_cls(self):
|
|
raise NotImplementedError('_reader')
|
|
|
|
@abstractmethod
|
|
def _write(self, env, days, path, data):
|
|
raise NotImplementedError('_write')
|
|
|
|
def __init__(self, env, days, data, path=None):
|
|
super(_TmpBarReader, self).__init__(path=path)
|
|
self._env = env
|
|
self._days = days
|
|
self._data = data
|
|
|
|
def __enter__(self):
|
|
tmpdir = super(_TmpBarReader, self).__enter__()
|
|
env = self._env
|
|
try:
|
|
self._write(
|
|
env,
|
|
self._days,
|
|
tmpdir.path,
|
|
self._data,
|
|
)
|
|
return self._reader_cls(tmpdir.path)
|
|
except:
|
|
self.__exit__(None, None, None)
|
|
raise
|
|
|
|
|
|
class tmp_bcolz_equity_minute_bar_reader(_TmpBarReader):
|
|
"""A temporary BcolzMinuteBarReader object.
|
|
|
|
Parameters
|
|
----------
|
|
env : TradingEnvironment
|
|
The trading env.
|
|
days : pd.DatetimeIndex
|
|
The days to write for.
|
|
data : iterable[(int, pd.DataFrame)]
|
|
The data to write.
|
|
path : str, optional
|
|
The path to the directory to write the data into. If not given, this
|
|
will be a unique name.
|
|
|
|
See Also
|
|
--------
|
|
tmp_bcolz_equity_daily_bar_reader
|
|
"""
|
|
_reader_cls = BcolzMinuteBarReader
|
|
_write = staticmethod(write_bcolz_minute_data)
|
|
|
|
|
|
class tmp_bcolz_equity_daily_bar_reader(_TmpBarReader):
|
|
"""A temporary BcolzDailyBarReader object.
|
|
|
|
Parameters
|
|
----------
|
|
env : TradingEnvironment
|
|
The trading env.
|
|
days : pd.DatetimeIndex
|
|
The days to write for.
|
|
data : dict[int -> pd.DataFrame]
|
|
The data to write.
|
|
path : str, optional
|
|
The path to the directory to write the data into. If not given, this
|
|
will be a unique name.
|
|
|
|
See Also
|
|
--------
|
|
tmp_bcolz_equity_daily_bar_reader
|
|
"""
|
|
_reader_cls = BcolzDailyBarReader
|
|
|
|
@staticmethod
|
|
def _write(env, days, path, data):
|
|
BcolzDailyBarWriter(path, days).write(data)
|
|
|
|
|
|
@contextmanager
|
|
def patch_read_csv(url_map, module=pd, strict=False):
|
|
"""Patch pandas.read_csv to map lookups from url to another.
|
|
|
|
Parameters
|
|
----------
|
|
url_map : mapping[str or file-like object -> str or file-like object]
|
|
The mapping to use to redirect read_csv calls.
|
|
module : module, optional
|
|
The module to patch ``read_csv`` on. By default this is ``pandas``.
|
|
This should be set to another module if ``read_csv`` is early-bound
|
|
like ``from pandas import read_csv`` instead of late-bound like:
|
|
``import pandas as pd; pd.read_csv``.
|
|
strict : bool, optional
|
|
If true, then this will assert that ``read_csv`` is only called with
|
|
elements in the ``url_map``.
|
|
"""
|
|
read_csv = pd.read_csv
|
|
|
|
def patched_read_csv(filepath_or_buffer, *args, **kwargs):
|
|
if filepath_or_buffer in url_map:
|
|
return read_csv(url_map[filepath_or_buffer], *args, **kwargs)
|
|
elif not strict:
|
|
return read_csv(filepath_or_buffer, *args, **kwargs)
|
|
else:
|
|
raise AssertionError(
|
|
'attempted to call read_csv on %r which not in the url map' %
|
|
filepath_or_buffer,
|
|
)
|
|
|
|
with patch.object(module, 'read_csv', patched_read_csv):
|
|
yield
|
|
|
|
|
|
@curry
|
|
def ensure_doctest(f, name=None):
|
|
"""Ensure that an object gets doctested. This is useful for instances
|
|
of objects like curry or partial which are not discovered by default.
|
|
|
|
Parameters
|
|
----------
|
|
f : any
|
|
The thing to doctest.
|
|
name : str, optional
|
|
The name to use in the doctest function mapping. If this is None,
|
|
Then ``f.__name__`` will be used.
|
|
|
|
Returns
|
|
-------
|
|
f : any
|
|
``f`` unchanged.
|
|
"""
|
|
_getframe(2).f_globals.setdefault('__test__', {})[
|
|
f.__name__ if name is None else name
|
|
] = f
|
|
return f
|
|
|
|
|
|
####################################
|
|
# Shared factors for pipeline tests.
|
|
####################################
|
|
|
|
class AssetID(CustomFactor):
|
|
"""
|
|
CustomFactor that returns the AssetID of each asset.
|
|
|
|
Useful for providing a Factor that produces a different value for each
|
|
asset.
|
|
"""
|
|
window_length = 1
|
|
inputs = ()
|
|
|
|
def compute(self, today, assets, out):
|
|
out[:] = assets
|
|
|
|
|
|
class AssetIDPlusDay(CustomFactor):
|
|
window_length = 1
|
|
inputs = ()
|
|
|
|
def compute(self, today, assets, out):
|
|
out[:] = assets + today.day
|
|
|
|
|
|
class OpenPrice(CustomFactor):
|
|
window_length = 1
|
|
inputs = [USEquityPricing.open]
|
|
|
|
def compute(self, today, assets, out, open):
|
|
out[:] = open
|