Files
catalyst/zipline/testing/core.py
T
Andrew Daniels 37e6a48e99 ENH: Pass calendar instance to BcolzMinuteBarWriter (#1406)
* First pass.

* Improvements and fixes

- Update usages of BcolzMinuteBarWriter
- Updates with rebuilt example data
- Expose calendar from BcolzMinuteBarMetadata instead of calendar_name
- Keep market_opens and market_closes in metadata for compatibility

* Store start_session and end_session in minute bcolz metadata

- start_session replaces first_trading_day
- Add end_session to limit to correct days

* For last_available_dt, get last close from calendar to maintain tz

* Bumps version and handles earlier versionson read

* Rebuilt example data on python 3

* Indicate metadata fields that are deprecated
2016-08-18 15:41:26 -04:00

1521 lines
44 KiB
Python

from abc import ABCMeta, abstractmethod, abstractproperty
from contextlib import contextmanager
from functools import wraps
import gzip
from inspect import getargspec
from itertools import (
combinations,
count,
product,
)
import operator
import os
from os.path import abspath, dirname, join, realpath
import shutil
from sys import _getframe
import tempfile
from logbook import TestHandler
from mock import patch
from nose.tools import nottest
from numpy.testing import assert_allclose, assert_array_equal
import pandas as pd
from six import itervalues, iteritems, with_metaclass
from six.moves import filter, map
from sqlalchemy import create_engine
from testfixtures import TempDirectory
from toolz import concat, curry
from zipline.assets import AssetFinder, AssetDBWriter
from zipline.assets.synthetic import make_simple_equity_info
from zipline.data.data_portal import DataPortal
from zipline.data.minute_bars import (
BcolzMinuteBarReader,
BcolzMinuteBarWriter,
US_EQUITIES_MINUTES_PER_DAY
)
from zipline.data.us_equity_pricing import (
BcolzDailyBarReader,
BcolzDailyBarWriter,
SQLiteAdjustmentWriter,
)
from zipline.finance.trading import TradingEnvironment
from zipline.finance.order import ORDER_STATUS
from zipline.lib.labelarray import LabelArray
from zipline.pipeline.data import USEquityPricing
from zipline.pipeline.engine import SimplePipelineEngine
from zipline.pipeline.factors import CustomFactor
from zipline.pipeline.loaders.testing import make_seeded_random_loader
from zipline.utils import security_list
from zipline.utils.calendars import get_calendar
from zipline.utils.input_validation import expect_dimensions
from zipline.utils.numpy_utils import as_column
from zipline.utils.sentinel import sentinel
import numpy as np
from numpy import float64
EPOCH = pd.Timestamp(0, tz='UTC')
def seconds_to_timestamp(seconds):
return pd.Timestamp(seconds, unit='s', tz='UTC')
def to_utc(time_str):
"""Convert a string in US/Eastern time to UTC"""
return pd.Timestamp(time_str, tz='US/Eastern').tz_convert('UTC')
def str_to_seconds(s):
"""
Convert a pandas-intelligible string to (integer) seconds since UTC.
>>> from pandas import Timestamp
>>> (Timestamp('2014-01-01') - Timestamp(0)).total_seconds()
1388534400.0
>>> str_to_seconds('2014-01-01')
1388534400
"""
return int((pd.Timestamp(s, tz='UTC') - EPOCH).total_seconds())
def drain_zipline(test, zipline):
output = []
transaction_count = 0
msg_counter = 0
# start the simulation
for update in zipline:
msg_counter += 1
output.append(update)
if 'daily_perf' in update:
transaction_count += \
len(update['daily_perf']['transactions'])
return output, transaction_count
def check_algo_results(test,
results,
expected_transactions_count=None,
expected_order_count=None,
expected_positions_count=None,
sid=None):
if expected_transactions_count is not None:
txns = flatten_list(results["transactions"])
test.assertEqual(expected_transactions_count, len(txns))
if expected_positions_count is not None:
raise NotImplementedError
if expected_order_count is not None:
# de-dup orders on id, because orders are put back into perf packets
# whenever they a txn is filled
orders = set([order['id'] for order in
flatten_list(results["orders"])])
test.assertEqual(expected_order_count, len(orders))
def flatten_list(list):
return [item for sublist in list for item in sublist]
def assert_single_position(test, zipline):
output, transaction_count = drain_zipline(test, zipline)
if 'expected_transactions' in test.zipline_test_config:
test.assertEqual(
test.zipline_test_config['expected_transactions'],
transaction_count
)
else:
test.assertEqual(
test.zipline_test_config['order_count'],
transaction_count
)
# the final message is the risk report, the second to
# last is the final day's results. Positions is a list of
# dicts.
closing_positions = output[-2]['daily_perf']['positions']
# confirm that all orders were filled.
# iterate over the output updates, overwriting
# orders when they are updated. Then check the status on all.
orders_by_id = {}
for update in output:
if 'daily_perf' in update:
if 'orders' in update['daily_perf']:
for order in update['daily_perf']['orders']:
orders_by_id[order['id']] = order
for order in itervalues(orders_by_id):
test.assertEqual(
order['status'],
ORDER_STATUS.FILLED,
"")
test.assertEqual(
len(closing_positions),
1,
"Portfolio should have one position."
)
sid = test.zipline_test_config['sid']
test.assertEqual(
closing_positions[0]['sid'],
sid,
"Portfolio should have one position in " + str(sid)
)
return output, transaction_count
@contextmanager
def security_list_copy():
old_dir = security_list.SECURITY_LISTS_DIR
new_dir = tempfile.mkdtemp()
try:
for subdir in os.listdir(old_dir):
shutil.copytree(os.path.join(old_dir, subdir),
os.path.join(new_dir, subdir))
with patch.object(security_list, 'SECURITY_LISTS_DIR', new_dir), \
patch.object(security_list, 'using_copy', True,
create=True):
yield
finally:
shutil.rmtree(new_dir, True)
def add_security_data(adds, deletes):
if not hasattr(security_list, 'using_copy'):
raise Exception('add_security_data must be used within '
'security_list_copy context')
directory = os.path.join(
security_list.SECURITY_LISTS_DIR,
"leveraged_etf_list/20150127/20150125"
)
if not os.path.exists(directory):
os.makedirs(directory)
del_path = os.path.join(directory, "delete")
with open(del_path, 'w') as f:
for sym in deletes:
f.write(sym)
f.write('\n')
add_path = os.path.join(directory, "add")
with open(add_path, 'w') as f:
for sym in adds:
f.write(sym)
f.write('\n')
def all_pairs_matching_predicate(values, pred):
"""
Return an iterator of all pairs, (v0, v1) from values such that
`pred(v0, v1) == True`
Parameters
----------
values : iterable
pred : function
Returns
-------
pairs_iterator : generator
Generator yielding pairs matching `pred`.
Examples
--------
>>> from zipline.testing import all_pairs_matching_predicate
>>> from operator import eq, lt
>>> list(all_pairs_matching_predicate(range(5), eq))
[(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
>>> list(all_pairs_matching_predicate("abcd", lt))
[('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'c'), ('b', 'd'), ('c', 'd')]
"""
return filter(lambda pair: pred(*pair), product(values, repeat=2))
def product_upper_triangle(values, include_diagonal=False):
"""
Return an iterator over pairs, (v0, v1), drawn from values.
If `include_diagonal` is True, returns all pairs such that v0 <= v1.
If `include_diagonal` is False, returns all pairs such that v0 < v1.
"""
return all_pairs_matching_predicate(
values,
operator.le if include_diagonal else operator.lt,
)
def all_subindices(index):
"""
Return all valid sub-indices of a pandas Index.
"""
return (
index[start:stop]
for start, stop in product_upper_triangle(range(len(index) + 1))
)
def chrange(start, stop):
"""
Construct an iterable of length-1 strings beginning with `start` and ending
with `stop`.
Parameters
----------
start : str
The first character.
stop : str
The last character.
Returns
-------
chars: iterable[str]
Iterable of strings beginning with start and ending with stop.
Example
-------
>>> chrange('A', 'C')
['A', 'B', 'C']
"""
return list(map(chr, range(ord(start), ord(stop) + 1)))
def make_trade_data_for_asset_info(dates,
asset_info,
price_start,
price_step_by_date,
price_step_by_sid,
volume_start,
volume_step_by_date,
volume_step_by_sid,
frequency,
writer=None):
"""
Convert the asset info dataframe into a dataframe of trade data for each
sid, and write to the writer if provided. Write NaNs for locations where
assets did not exist. Return a dict of the dataframes, keyed by sid.
"""
trade_data = {}
sids = asset_info.index
price_sid_deltas = np.arange(len(sids), dtype=float64) * price_step_by_sid
price_date_deltas = (np.arange(len(dates), dtype=float64) *
price_step_by_date)
prices = (price_sid_deltas + as_column(price_date_deltas)) + price_start
volume_sid_deltas = np.arange(len(sids)) * volume_step_by_sid
volume_date_deltas = np.arange(len(dates)) * volume_step_by_date
volumes = volume_sid_deltas + as_column(volume_date_deltas) + volume_start
for j, sid in enumerate(sids):
start_date, end_date = asset_info.loc[sid, ['start_date', 'end_date']]
# Normalize here so the we still generate non-NaN values on the minutes
# for an asset's last trading day.
for i, date in enumerate(dates.normalize()):
if not (start_date <= date <= end_date):
prices[i, j] = 0
volumes[i, j] = 0
df = pd.DataFrame(
{
"open": prices[:, j],
"high": prices[:, j],
"low": prices[:, j],
"close": prices[:, j],
"volume": volumes[:, j],
},
index=dates,
)
if writer:
writer.write_sid(sid, df)
trade_data[sid] = df
return trade_data
def check_allclose(actual,
desired,
rtol=1e-07,
atol=0,
err_msg='',
verbose=True):
"""
Wrapper around np.testing.assert_allclose that also verifies that inputs
are ndarrays.
See Also
--------
np.assert_allclose
"""
if type(actual) != type(desired):
raise AssertionError("%s != %s" % (type(actual), type(desired)))
return assert_allclose(
actual,
desired,
atol=atol,
rtol=rtol,
err_msg=err_msg,
verbose=verbose,
)
def check_arrays(x, y, err_msg='', verbose=True, check_dtypes=True):
"""
Wrapper around np.testing.assert_array_equal that also verifies that inputs
are ndarrays.
See Also
--------
np.assert_array_equal
"""
assert type(x) == type(y), "{x} != {y}".format(x=type(x), y=type(y))
assert x.dtype == y.dtype, "{x.dtype} != {y.dtype}".format(x=x, y=y)
if isinstance(x, LabelArray):
# Check that both arrays have missing values in the same locations...
assert_array_equal(
x.is_missing(),
y.is_missing(),
err_msg=err_msg,
verbose=verbose,
)
# ...then check the actual values as well.
x = x.as_string_array()
y = y.as_string_array()
return assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)
class UnexpectedAttributeAccess(Exception):
pass
class ExplodingObject(object):
"""
Object that will raise an exception on any attribute access.
Useful for verifying that an object is never touched during a
function/method call.
"""
def __getattribute__(self, name):
raise UnexpectedAttributeAccess(name)
def write_minute_data(trading_calendar, tempdir, minutes, sids):
first_session = trading_calendar.minute_to_session_label(
minutes[0], direction="none"
)
last_session = trading_calendar.minute_to_session_label(
minutes[-1], direction="none"
)
sessions = trading_calendar.sessions_in_range(first_session, last_session)
write_bcolz_minute_data(
trading_calendar,
sessions,
tempdir.path,
create_minute_bar_data(minutes, sids),
)
return tempdir.path
def create_minute_bar_data(minutes, sids):
length = len(minutes)
for sid_idx, sid in enumerate(sids):
yield sid, pd.DataFrame(
{
'open': np.arange(length) + 10 + sid_idx,
'high': np.arange(length) + 15 + sid_idx,
'low': np.arange(length) + 8 + sid_idx,
'close': np.arange(length) + 10 + sid_idx,
'volume': 100 + sid_idx,
},
index=minutes,
)
def create_daily_bar_data(sessions, sids):
length = len(sessions)
for sid_idx, sid in enumerate(sids):
yield sid, pd.DataFrame(
{
"open": (np.array(range(10, 10 + length)) + sid_idx),
"high": (np.array(range(15, 15 + length)) + sid_idx),
"low": (np.array(range(8, 8 + length)) + sid_idx),
"close": (np.array(range(10, 10 + length)) + sid_idx),
"volume": np.array(range(100, 100 + length)) + sid_idx,
"day": [session.value for session in sessions]
},
index=sessions,
)
def write_daily_data(tempdir, sim_params, sids, trading_calendar):
path = os.path.join(tempdir.path, "testdaily.bcolz")
BcolzDailyBarWriter(path, trading_calendar,
sim_params.start_session,
sim_params.end_session).write(
create_daily_bar_data(sim_params.sessions, sids),
)
return path
def create_data_portal(asset_finder, tempdir, sim_params, sids,
trading_calendar, adjustment_reader=None):
if sim_params.data_frequency == "daily":
daily_path = write_daily_data(tempdir, sim_params, sids,
trading_calendar)
equity_daily_reader = BcolzDailyBarReader(daily_path)
return DataPortal(
asset_finder, trading_calendar,
first_trading_day=equity_daily_reader.first_trading_day,
equity_daily_reader=equity_daily_reader,
adjustment_reader=adjustment_reader
)
else:
minutes = trading_calendar.minutes_in_range(
sim_params.first_open,
sim_params.last_close
)
minute_path = write_minute_data(trading_calendar, tempdir, minutes,
sids)
equity_minute_reader = BcolzMinuteBarReader(minute_path)
return DataPortal(
asset_finder, trading_calendar,
first_trading_day=equity_minute_reader.first_trading_day,
equity_minute_reader=equity_minute_reader,
adjustment_reader=adjustment_reader
)
def write_bcolz_minute_data(trading_calendar, days, path, data):
BcolzMinuteBarWriter(
path,
trading_calendar,
days[0],
days[-1],
US_EQUITIES_MINUTES_PER_DAY
).write(data)
def create_minute_df_for_asset(trading_calendar,
start_dt,
end_dt,
interval=1,
start_val=1,
minute_blacklist=None):
asset_minutes = trading_calendar.minutes_for_sessions_in_range(
start_dt, end_dt
)
minutes_count = len(asset_minutes)
minutes_arr = np.array(range(start_val, start_val + minutes_count))
df = pd.DataFrame(
{
"open": minutes_arr + 1,
"high": minutes_arr + 2,
"low": minutes_arr - 1,
"close": minutes_arr,
"volume": 100 * minutes_arr,
},
index=asset_minutes,
)
if interval > 1:
counter = 0
while counter < len(minutes_arr):
df[counter:(counter + interval - 1)] = 0
counter += interval
if minute_blacklist is not None:
for minute in minute_blacklist:
df.loc[minute] = 0
return df
def create_daily_df_for_asset(trading_calendar, start_day, end_day,
interval=1):
days = trading_calendar.sessions_in_range(start_day, end_day)
days_count = len(days)
days_arr = np.arange(days_count) + 2
df = pd.DataFrame(
{
"open": days_arr + 1,
"high": days_arr + 2,
"low": days_arr - 1,
"close": days_arr,
"volume": days_arr * 100,
},
index=days,
)
if interval > 1:
# only keep every 'interval' rows
for idx, _ in enumerate(days_arr):
if (idx + 1) % interval != 0:
df["open"].iloc[idx] = 0
df["high"].iloc[idx] = 0
df["low"].iloc[idx] = 0
df["close"].iloc[idx] = 0
df["volume"].iloc[idx] = 0
return df
def trades_by_sid_to_dfs(trades_by_sid, index):
for sidint, trades in iteritems(trades_by_sid):
opens = []
highs = []
lows = []
closes = []
volumes = []
for trade in trades:
opens.append(trade["open_price"])
highs.append(trade["high"])
lows.append(trade["low"])
closes.append(trade["close_price"])
volumes.append(trade["volume"])
yield sidint, pd.DataFrame(
{
"open": opens,
"high": highs,
"low": lows,
"close": closes,
"volume": volumes,
},
index=index,
)
def create_data_portal_from_trade_history(asset_finder, trading_calendar,
tempdir, sim_params, trades_by_sid):
if sim_params.data_frequency == "daily":
path = os.path.join(tempdir.path, "testdaily.bcolz")
writer = BcolzDailyBarWriter(
path, trading_calendar,
sim_params.start_session,
sim_params.end_session
)
writer.write(
trades_by_sid_to_dfs(trades_by_sid, sim_params.sessions),
)
equity_daily_reader = BcolzDailyBarReader(path)
return DataPortal(
asset_finder, trading_calendar,
first_trading_day=equity_daily_reader.first_trading_day,
equity_daily_reader=equity_daily_reader,
)
else:
minutes = trading_calendar.minutes_in_range(
sim_params.first_open,
sim_params.last_close
)
length = len(minutes)
assets = {}
for sidint, trades in iteritems(trades_by_sid):
opens = np.zeros(length)
highs = np.zeros(length)
lows = np.zeros(length)
closes = np.zeros(length)
volumes = np.zeros(length)
for trade in trades:
# put them in the right place
idx = minutes.searchsorted(trade.dt)
opens[idx] = trade.open_price * 1000
highs[idx] = trade.high * 1000
lows[idx] = trade.low * 1000
closes[idx] = trade.close_price * 1000
volumes[idx] = trade.volume
assets[sidint] = pd.DataFrame({
"open": opens,
"high": highs,
"low": lows,
"close": closes,
"volume": volumes,
"dt": minutes
}).set_index("dt")
write_bcolz_minute_data(
trading_calendar,
sim_params.sessions,
tempdir.path,
assets
)
equity_minute_reader = BcolzMinuteBarReader(tempdir.path)
return DataPortal(
asset_finder, trading_calendar,
first_trading_day=equity_minute_reader.first_trading_day,
equity_minute_reader=equity_minute_reader,
)
class FakeDataPortal(DataPortal):
def __init__(self, env=None, trading_calendar=None,
first_trading_day=None):
if env is None:
env = TradingEnvironment()
if trading_calendar is None:
trading_calendar = get_calendar("NYSE")
super(FakeDataPortal, self).__init__(env.asset_finder,
trading_calendar,
first_trading_day)
def get_spot_value(self, asset, field, dt, data_frequency):
if field == "volume":
return 100
else:
return 1.0
def get_history_window(self, assets, end_dt, bar_count, frequency, field,
ffill=True):
if frequency == "1d":
end_idx = \
self.trading_calendar.all_sessions.searchsorted(end_dt)
days = self.trading_calendar.all_sessions[
(end_idx - bar_count + 1):(end_idx + 1)
]
df = pd.DataFrame(
np.full((bar_count, len(assets)), 100),
index=days,
columns=assets
)
return df
class FetcherDataPortal(DataPortal):
"""
Mock dataportal that returns fake data for history and non-fetcher
spot value.
"""
def __init__(self, asset_finder, trading_calendar, first_trading_day=None):
super(FetcherDataPortal, self).__init__(asset_finder, trading_calendar,
first_trading_day)
def get_spot_value(self, asset, field, dt, data_frequency):
# if this is a fetcher field, exercise the regular code path
if self._is_extra_source(asset, field, self._augmented_sources_map):
return super(FetcherDataPortal, self).get_spot_value(
asset, field, dt, data_frequency)
# otherwise just return a fixed value
return int(asset)
def _get_daily_window_for_sid(self, asset, field, days_in_window,
extra_slot=True):
return np.arange(days_in_window, dtype=np.float64)
def _get_minute_window_for_asset(self, asset, field, minutes_for_window):
return np.arange(minutes_for_window, dtype=np.float64)
class tmp_assets_db(object):
"""Create a temporary assets sqlite database.
This is meant to be used as a context manager.
Parameters
----------
url : string
The URL for the database connection.
**frames
The frames to pass to the AssetDBWriter.
By default this maps equities:
('A', 'B', 'C') -> map(ord, 'ABC')
See Also
--------
empty_assets_db
tmp_asset_finder
"""
_default_equities = sentinel('_default_equities')
def __init__(self,
url='sqlite:///:memory:',
equities=_default_equities,
**frames):
self._url = url
self._eng = None
if equities is self._default_equities:
equities = make_simple_equity_info(
list(map(ord, 'ABC')),
pd.Timestamp(0),
pd.Timestamp('2015'),
)
frames['equities'] = equities
self._frames = frames
self._eng = None # set in enter and exit
def __enter__(self):
self._eng = eng = create_engine(self._url)
AssetDBWriter(eng).write(**self._frames)
return eng
def __exit__(self, *excinfo):
assert self._eng is not None, '_eng was not set in __enter__'
self._eng.dispose()
self._eng = None
def empty_assets_db():
"""Context manager for creating an empty assets db.
See Also
--------
tmp_assets_db
"""
return tmp_assets_db(equities=None)
class tmp_asset_finder(tmp_assets_db):
"""Create a temporary asset finder using an in memory sqlite db.
Parameters
----------
url : string
The URL for the database connection.
finder_cls : type, optional
The type of asset finder to create from the assets db.
**frames
Forwarded to ``tmp_assets_db``.
See Also
--------
tmp_assets_db
"""
def __init__(self,
url='sqlite:///:memory:',
finder_cls=AssetFinder,
**frames):
self._finder_cls = finder_cls
super(tmp_asset_finder, self).__init__(url=url, **frames)
def __enter__(self):
return self._finder_cls(super(tmp_asset_finder, self).__enter__())
def empty_asset_finder():
"""Context manager for creating an empty asset finder.
See Also
--------
empty_assets_db
tmp_assets_db
tmp_asset_finder
"""
return tmp_asset_finder(equities=None)
class tmp_trading_env(tmp_asset_finder):
"""Create a temporary trading environment.
Parameters
----------
finder_cls : type, optional
The type of asset finder to create from the assets db.
**frames
Forwarded to ``tmp_assets_db``.
See Also
--------
empty_trading_env
tmp_asset_finder
"""
def __enter__(self):
return TradingEnvironment(
asset_db_path=super(tmp_trading_env, self).__enter__().engine,
)
def empty_trading_env():
return tmp_trading_env(equities=None)
class SubTestFailures(AssertionError):
def __init__(self, *failures):
self.failures = failures
def __str__(self):
return 'failures:\n %s' % '\n '.join(
'\n '.join((
', '.join('%s=%r' % item for item in scope.items()),
'%s: %s' % (type(exc).__name__, exc),
)) for scope, exc in self.failures,
)
@nottest
def subtest(iterator, *_names):
"""
Construct a subtest in a unittest.
Consider using ``zipline.testing.parameter_space`` when subtests
are constructed over a single input or over the cross-product of multiple
inputs.
``subtest`` works by decorating a function as a subtest. The decorated
function will be run by iterating over the ``iterator`` and *unpacking the
values into the function. If any of the runs fail, the result will be put
into a set and the rest of the tests will be run. Finally, if any failed,
all of the results will be dumped as one failure.
Parameters
----------
iterator : iterable[iterable]
The iterator of arguments to pass to the function.
*name : iterator[str]
The names to use for each element of ``iterator``. These will be used
to print the scope when a test fails. If not provided, it will use the
integer index of the value as the name.
Examples
--------
::
class MyTest(TestCase):
def test_thing(self):
# Example usage inside another test.
@subtest(([n] for n in range(100000)), 'n')
def subtest(n):
self.assertEqual(n % 2, 0, 'n was not even')
subtest()
@subtest(([n] for n in range(100000)), 'n')
def test_decorated_function(self, n):
# Example usage to parameterize an entire function.
self.assertEqual(n % 2, 1, 'n was not odd')
Notes
-----
We use this when we:
* Will never want to run each parameter individually.
* Have a large parameter space we are testing
(see tests/utils/test_events.py).
``nose_parameterized.expand`` will create a test for each parameter
combination which bloats the test output and makes the travis pages slow.
We cannot use ``unittest2.TestCase.subTest`` because nose, pytest, and
nose2 do not support ``addSubTest``.
See Also
--------
zipline.testing.parameter_space
"""
def dec(f):
@wraps(f)
def wrapped(*args, **kwargs):
names = _names
failures = []
for scope in iterator:
scope = tuple(scope)
try:
f(*args + scope, **kwargs)
except Exception as e:
if not names:
names = count()
failures.append((dict(zip(names, scope)), e))
if failures:
raise SubTestFailures(*failures)
return wrapped
return dec
class MockDailyBarReader(object):
def spot_price(self, col, sid, dt):
return 100
def create_mock_adjustment_data(splits=None, dividends=None, mergers=None):
if splits is None:
splits = create_empty_splits_mergers_frame()
elif not isinstance(splits, pd.DataFrame):
splits = pd.DataFrame(splits)
if mergers is None:
mergers = create_empty_splits_mergers_frame()
elif not isinstance(mergers, pd.DataFrame):
mergers = pd.DataFrame(mergers)
if dividends is None:
dividends = create_empty_dividends_frame()
elif not isinstance(dividends, pd.DataFrame):
dividends = pd.DataFrame(dividends)
return splits, mergers, dividends
def create_mock_adjustments(tempdir, days, splits=None, dividends=None,
mergers=None):
path = tempdir.getpath("test_adjustments.db")
SQLiteAdjustmentWriter(path, MockDailyBarReader(), days).write(
*create_mock_adjustment_data(splits, dividends, mergers)
)
return path
def assert_timestamp_equal(left, right, compare_nat_equal=True, msg=""):
"""
Assert that two pandas Timestamp objects are the same.
Parameters
----------
left, right : pd.Timestamp
The values to compare.
compare_nat_equal : bool, optional
Whether to consider `NaT` values equal. Defaults to True.
msg : str, optional
A message to forward to `pd.util.testing.assert_equal`.
"""
if compare_nat_equal and left is pd.NaT and right is pd.NaT:
return
return pd.util.testing.assert_equal(left, right, msg=msg)
def powerset(values):
"""
Return the power set (i.e., the set of all subsets) of entries in `values`.
"""
return concat(combinations(values, i) for i in range(len(values) + 1))
def to_series(knowledge_dates, earning_dates):
"""
Helper for converting a dict of strings to a Series of datetimes.
This is just for making the test cases more readable.
"""
return pd.Series(
index=pd.to_datetime(knowledge_dates),
data=pd.to_datetime(earning_dates),
)
def gen_calendars(start, stop, critical_dates):
"""
Generate calendars to use as inputs.
"""
all_dates = pd.date_range(start, stop, tz='utc')
for to_drop in map(list, powerset(critical_dates)):
# Have to yield tuples.
yield (all_dates.drop(to_drop),)
# Also test with the trading calendar.
trading_days = get_calendar("NYSE").all_days
yield (trading_days[trading_days.slice_indexer(start, stop)],)
@contextmanager
def temp_pipeline_engine(calendar, sids, random_seed, symbols=None):
"""
A contextManager that yields a SimplePipelineEngine holding a reference to
an AssetFinder generated via tmp_asset_finder.
Parameters
----------
calendar : pd.DatetimeIndex
Calendar to pass to the constructed PipelineEngine.
sids : iterable[int]
Sids to use for the temp asset finder.
random_seed : int
Integer used to seed instances of SeededRandomLoader.
symbols : iterable[str], optional
Symbols for constructed assets. Forwarded to make_simple_equity_info.
"""
equity_info = make_simple_equity_info(
sids=sids,
start_date=calendar[0],
end_date=calendar[-1],
symbols=symbols,
)
loader = make_seeded_random_loader(random_seed, calendar, sids)
get_loader = lambda column: loader
with tmp_asset_finder(equities=equity_info) as finder:
yield SimplePipelineEngine(get_loader, calendar, finder)
def parameter_space(__fail_fast=False, **params):
"""
Wrapper around subtest that allows passing keywords mapping names to
iterables of values.
The decorated test function will be called with the cross-product of all
possible inputs
Usage
-----
>>> from unittest import TestCase
>>> class SomeTestCase(TestCase):
... @parameter_space(x=[1, 2], y=[2, 3])
... def test_some_func(self, x, y):
... # Will be called with every possible combination of x and y.
... self.assertEqual(somefunc(x, y), expected_result(x, y))
See Also
--------
zipline.testing.subtest
"""
def decorator(f):
argspec = getargspec(f)
if argspec.varargs:
raise AssertionError("parameter_space() doesn't support *args")
if argspec.keywords:
raise AssertionError("parameter_space() doesn't support **kwargs")
if argspec.defaults:
raise AssertionError("parameter_space() doesn't support defaults.")
# Skip over implicit self.
argnames = argspec.args
if argnames[0] == 'self':
argnames = argnames[1:]
extra = set(params) - set(argnames)
if extra:
raise AssertionError(
"Keywords %s supplied to parameter_space() are "
"not in function signature." % extra
)
unspecified = set(argnames) - set(params)
if unspecified:
raise AssertionError(
"Function arguments %s were not "
"supplied to parameter_space()." % extra
)
param_sets = product(*(params[name] for name in argnames))
if __fail_fast:
@wraps(f)
def wrapped(self):
for args in param_sets:
f(self, *args)
return wrapped
else:
return subtest(param_sets, *argnames)(f)
return decorator
def create_empty_dividends_frame():
return pd.DataFrame(
np.array(
[],
dtype=[
('ex_date', 'datetime64[ns]'),
('pay_date', 'datetime64[ns]'),
('record_date', 'datetime64[ns]'),
('declared_date', 'datetime64[ns]'),
('amount', 'float64'),
('sid', 'int32'),
],
),
index=pd.DatetimeIndex([], tz='UTC'),
)
def create_empty_splits_mergers_frame():
return pd.DataFrame(
np.array(
[],
dtype=[
('effective_date', 'int64'),
('ratio', 'float64'),
('sid', 'int64'),
],
),
index=pd.DatetimeIndex([]),
)
def make_alternating_boolean_array(shape, first_value=True):
"""
Create a 2D numpy array with the given shape containing alternating values
of False, True, False, True,... along each row and each column.
Examples
--------
>>> make_alternating_boolean_array((4,4))
array([[ True, False, True, False],
[False, True, False, True],
[ True, False, True, False],
[False, True, False, True]], dtype=bool)
>>> make_alternating_boolean_array((4,3), first_value=False)
array([[False, True, False],
[ True, False, True],
[False, True, False],
[ True, False, True]], dtype=bool)
"""
if len(shape) != 2:
raise ValueError(
'Shape must be 2-dimensional. Given shape was {}'.format(shape)
)
alternating = np.empty(shape, dtype=np.bool)
for row in alternating:
row[::2] = first_value
row[1::2] = not(first_value)
first_value = not(first_value)
return alternating
def make_cascading_boolean_array(shape, first_value=True):
"""
Create a numpy array with the given shape containing cascading boolean
values, with `first_value` being the top-left value.
Examples
--------
>>> make_cascading_boolean_array((4,4))
array([[ True, True, True, False],
[ True, True, False, False],
[ True, False, False, False],
[False, False, False, False]], dtype=bool)
>>> make_cascading_boolean_array((4,2))
array([[ True, False],
[False, False],
[False, False],
[False, False]], dtype=bool)
>>> make_cascading_boolean_array((2,4))
array([[ True, True, True, False],
[ True, True, False, False]], dtype=bool)
"""
if len(shape) != 2:
raise ValueError(
'Shape must be 2-dimensional. Given shape was {}'.format(shape)
)
cascading = np.full(shape, not(first_value), dtype=np.bool)
ending_col = shape[1] - 1
for row in cascading:
if ending_col > 0:
row[:ending_col] = first_value
ending_col -= 1
else:
break
return cascading
@expect_dimensions(array=2)
def permute_rows(seed, array):
"""
Shuffle each row in ``array`` based on permutations generated by ``seed``.
Parameters
----------
seed : int
Seed for numpy.RandomState
array : np.ndarray[ndim=2]
Array over which to apply permutations.
"""
rand = np.random.RandomState(seed)
return np.apply_along_axis(rand.permutation, 1, array)
@nottest
def make_test_handler(testcase, *args, **kwargs):
"""
Returns a TestHandler which will be used by the given testcase. This
handler can be used to test log messages.
Parameters
----------
testcase: unittest.TestCase
The test class in which the log handler will be used.
*args, **kwargs
Forwarded to the new TestHandler object.
Returns
-------
handler: logbook.TestHandler
The handler to use for the test case.
"""
handler = TestHandler(*args, **kwargs)
testcase.addCleanup(handler.close)
return handler
def write_compressed(path, content):
"""
Write a compressed (gzipped) file to `path`.
"""
with gzip.open(path, 'wb') as f:
f.write(content)
def read_compressed(path):
"""
Write a compressed (gzipped) file from `path`.
"""
with gzip.open(path, 'rb') as f:
return f.read()
zipline_git_root = abspath(
join(realpath(dirname(__file__)), '..', '..'),
)
@nottest
def test_resource_path(*path_parts):
return os.path.join(zipline_git_root, 'tests', 'resources', *path_parts)
@contextmanager
def patch_os_environment(remove=None, **values):
"""
Context manager for patching the operating system environment.
"""
old_values = {}
remove = remove or []
for key in remove:
old_values[key] = os.environ.pop(key)
for key, value in values.iteritems():
old_values[key] = os.getenv(key)
os.environ[key] = value
try:
yield
finally:
for old_key, old_value in old_values.iteritems():
if old_value is None:
# Value was not present when we entered, so del it out if it's
# still present.
try:
del os.environ[key]
except KeyError:
pass
else:
# Restore the old value.
os.environ[old_key] = old_value
class tmp_dir(TempDirectory, object):
"""New style class that wrapper for TempDirectory in python 2.
"""
pass
class _TmpBarReader(with_metaclass(ABCMeta, tmp_dir)):
"""A helper for tmp_bcolz_equity_minute_bar_reader and
tmp_bcolz_equity_daily_bar_reader.
Parameters
----------
env : TradingEnvironment
The trading env.
days : pd.DatetimeIndex
The days to write for.
data : dict[int -> pd.DataFrame]
The data to write.
path : str, optional
The path to the directory to write the data into. If not given, this
will be a unique name.
"""
@abstractproperty
def _reader_cls(self):
raise NotImplementedError('_reader')
@abstractmethod
def _write(self, env, days, path, data):
raise NotImplementedError('_write')
def __init__(self, env, days, data, path=None):
super(_TmpBarReader, self).__init__(path=path)
self._env = env
self._days = days
self._data = data
def __enter__(self):
tmpdir = super(_TmpBarReader, self).__enter__()
env = self._env
try:
self._write(
env,
self._days,
tmpdir.path,
self._data,
)
return self._reader_cls(tmpdir.path)
except:
self.__exit__(None, None, None)
raise
class tmp_bcolz_equity_minute_bar_reader(_TmpBarReader):
"""A temporary BcolzMinuteBarReader object.
Parameters
----------
env : TradingEnvironment
The trading env.
days : pd.DatetimeIndex
The days to write for.
data : iterable[(int, pd.DataFrame)]
The data to write.
path : str, optional
The path to the directory to write the data into. If not given, this
will be a unique name.
See Also
--------
tmp_bcolz_equity_daily_bar_reader
"""
_reader_cls = BcolzMinuteBarReader
_write = staticmethod(write_bcolz_minute_data)
class tmp_bcolz_equity_daily_bar_reader(_TmpBarReader):
"""A temporary BcolzDailyBarReader object.
Parameters
----------
env : TradingEnvironment
The trading env.
days : pd.DatetimeIndex
The days to write for.
data : dict[int -> pd.DataFrame]
The data to write.
path : str, optional
The path to the directory to write the data into. If not given, this
will be a unique name.
See Also
--------
tmp_bcolz_equity_daily_bar_reader
"""
_reader_cls = BcolzDailyBarReader
@staticmethod
def _write(env, days, path, data):
BcolzDailyBarWriter(path, days).write(data)
@contextmanager
def patch_read_csv(url_map, module=pd, strict=False):
"""Patch pandas.read_csv to map lookups from url to another.
Parameters
----------
url_map : mapping[str or file-like object -> str or file-like object]
The mapping to use to redirect read_csv calls.
module : module, optional
The module to patch ``read_csv`` on. By default this is ``pandas``.
This should be set to another module if ``read_csv`` is early-bound
like ``from pandas import read_csv`` instead of late-bound like:
``import pandas as pd; pd.read_csv``.
strict : bool, optional
If true, then this will assert that ``read_csv`` is only called with
elements in the ``url_map``.
"""
read_csv = pd.read_csv
def patched_read_csv(filepath_or_buffer, *args, **kwargs):
if filepath_or_buffer in url_map:
return read_csv(url_map[filepath_or_buffer], *args, **kwargs)
elif not strict:
return read_csv(filepath_or_buffer, *args, **kwargs)
else:
raise AssertionError(
'attempted to call read_csv on %r which not in the url map' %
filepath_or_buffer,
)
with patch.object(module, 'read_csv', patched_read_csv):
yield
@curry
def ensure_doctest(f, name=None):
"""Ensure that an object gets doctested. This is useful for instances
of objects like curry or partial which are not discovered by default.
Parameters
----------
f : any
The thing to doctest.
name : str, optional
The name to use in the doctest function mapping. If this is None,
Then ``f.__name__`` will be used.
Returns
-------
f : any
``f`` unchanged.
"""
_getframe(2).f_globals.setdefault('__test__', {})[
f.__name__ if name is None else name
] = f
return f
####################################
# Shared factors for pipeline tests.
####################################
class AssetID(CustomFactor):
"""
CustomFactor that returns the AssetID of each asset.
Useful for providing a Factor that produces a different value for each
asset.
"""
window_length = 1
inputs = ()
def compute(self, today, assets, out):
out[:] = assets
class AssetIDPlusDay(CustomFactor):
window_length = 1
inputs = ()
def compute(self, today, assets, out):
out[:] = assets + today.day
class OpenPrice(CustomFactor):
window_length = 1
inputs = [USEquityPricing.open]
def compute(self, today, assets, out, open):
out[:] = open