Merge pull request #1394 from quantopian/downsample

Add Generic Downsampling to Pipeline
This commit is contained in:
Scott Sanderson
2016-08-18 12:15:54 -04:00
committed by GitHub
32 changed files with 2118 additions and 491 deletions
+1 -1
View File
@@ -101,7 +101,7 @@ install:
- pip freeze | sort
test_script:
- nosetests
- nosetests -e zipline.utils.numpy_utils
- flake8 zipline tests
branches:
+41 -31
View File
@@ -1,24 +1,23 @@
"""
Base class for Pipeline API unittests.
Base class for Pipeline API unit tests.
"""
from functools import wraps
import numpy as np
from numpy import arange, prod
from pandas import date_range, Int64Index, DataFrame
from pandas import DataFrame, Timestamp
from six import iteritems
from zipline.assets.synthetic import make_simple_equity_info
from zipline.pipeline.engine import SimplePipelineEngine
from zipline.pipeline import TermGraph
from zipline.pipeline.term import AssetExists
from zipline.pipeline import ExecutionPlan
from zipline.pipeline.term import AssetExists, InputDates
from zipline.testing import (
check_arrays,
ExplodingObject,
tmp_asset_finder,
)
from zipline.testing.fixtures import (
WithTradingCalendars,
WithAssetFinder,
WithTradingSessions,
ZiplineTestCase,
)
@@ -53,32 +52,26 @@ def with_defaults(**default_funcs):
with_default_shape = with_defaults(shape=lambda self: self.default_shape)
class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase):
class BasePipelineTestCase(WithTradingSessions,
WithAssetFinder,
ZiplineTestCase):
START_DATE = Timestamp('2014', tz='UTC')
END_DATE = Timestamp('2014-12-31', tz='UTC')
ASSET_FINDER_EQUITY_SIDS = list(range(20))
@classmethod
def init_class_fixtures(cls):
super(BasePipelineTestCase, cls).init_class_fixtures()
cls.__calendar = date_range('2014', '2015',
freq=cls.trading_calendar.day)
cls.__assets = assets = Int64Index(arange(1, 20))
cls.__tmp_finder_ctx = tmp_asset_finder(
equities=make_simple_equity_info(
assets,
cls.__calendar[0],
cls.__calendar[-1],
)
)
cls.__finder = cls.__tmp_finder_ctx.__enter__()
cls.__mask = cls.__finder.lifetimes(
cls.__calendar[-30:],
cls.default_asset_exists_mask = cls.asset_finder.lifetimes(
cls.nyse_sessions[-30:],
include_start_date=False,
)
@property
def default_shape(self):
"""Default shape for methods that build test data."""
return self.__mask.shape
return self.default_asset_exists_mask.shape
def run_graph(self, graph, initial_workspace, mask=None):
"""
@@ -103,14 +96,17 @@ class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase):
"""
engine = SimplePipelineEngine(
lambda column: ExplodingObject(),
self.__calendar,
self.__finder,
self.nyse_sessions,
self.asset_finder,
)
if mask is None:
mask = self.__mask
mask = self.default_asset_exists_mask
dates, assets, mask_values = explode(mask)
initial_workspace.setdefault(AssetExists(), mask_values)
initial_workspace.setdefault(InputDates(), dates)
return engine.compute_chunk(
graph,
dates,
@@ -118,15 +114,29 @@ class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase):
initial_workspace,
)
def check_terms(self, terms, expected, initial_workspace, mask):
def check_terms(self,
terms,
expected,
initial_workspace,
mask,
check=check_arrays):
"""
Compile the given terms into an ExecutionPlan, compute it with
initial_workspace, and compare the results with ``expected`` using ``check``.
"""
graph = TermGraph(terms)
start_date, end_date = mask.index[[0, -1]]
graph = ExecutionPlan(
terms,
all_dates=self.nyse_sessions,
start_date=start_date,
end_date=end_date,
)
results = self.run_graph(graph, initial_workspace, mask)
for key, (res, exp) in dzip_exact(results, expected).items():
check_arrays(res, exp)
check(res, exp)
return results
def build_mask(self, array):
"""
@@ -138,13 +148,13 @@ class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase):
array,
# Use the **last** N dates rather than the first N so that we have
# space for lookbacks.
index=self.__calendar[-ndates:],
columns=self.__assets[:nassets],
index=self.nyse_sessions[-ndates:],
columns=self.ASSET_FINDER_EQUITY_SIDS[:nassets],
dtype=bool,
)
@with_default_shape
def arange_data(self, shape, dtype=float):
def arange_data(self, shape, dtype=np.float64):
"""
Build a block of testing data from numpy.arange.
"""
+686
View File
@@ -0,0 +1,686 @@
"""
Tests for Downsampled Filters/Factors/Classifiers
"""
import pandas as pd
from pandas.util.testing import assert_frame_equal
from zipline.pipeline import (
Pipeline,
CustomFactor,
CustomFilter,
CustomClassifier,
)
from zipline.pipeline.data.testing import TestingDataSet
from zipline.pipeline.factors import SimpleMovingAverage
from zipline.pipeline.filters.smoothing import All
from zipline.testing import ZiplineTestCase, parameter_space
from zipline.testing.fixtures import (
WithTradingSessions,
WithSeededRandomPipelineEngine,
)
from zipline.utils.input_validation import _qualified_name
from zipline.utils.numpy_utils import int64_dtype
# Test factor: reads a window of TestingDataSet.float_col and outputs the
# window's row 0.  The window length is supplied by the caller at
# construction time (e.g. ``NDaysAgoFactor(window_length=11)``).
class NDaysAgoFactor(CustomFactor):
inputs = [TestingDataSet.float_col]
def compute(self, today, assets, out, floats):
# Copy row 0 of the input window into the output buffer in place.
# Presumably row 0 is the oldest row of the window (hence "N days ago")
# -- confirm against the CustomFactor windowing contract.
out[:] = floats[0]
# Test filter: reads a window of TestingDataSet.bool_col and outputs the
# window's row 0.  The window length is supplied by the caller at
# construction time (e.g. ``NDaysAgoFilter(window_length=11)``).
class NDaysAgoFilter(CustomFilter):
inputs = [TestingDataSet.bool_col]
def compute(self, today, assets, out, bools):
# Copy row 0 of the input window into the output buffer in place.
# Presumably row 0 is the oldest row of the window -- confirm against the
# CustomFilter windowing contract.
out[:] = bools[0]
# Test classifier: reads a window of TestingDataSet.categorical_col and
# outputs the window's row 0.  The window length is supplied by the caller at
# construction time (e.g. ``NDaysAgoClassifier(window_length=11)``).
class NDaysAgoClassifier(CustomClassifier):
inputs = [TestingDataSet.categorical_col]
# The output dtype is taken directly from the input column.
dtype = TestingDataSet.categorical_col.dtype
def compute(self, today, assets, out, cats):
# Copy row 0 of the input window into the output buffer in place.
# Presumably row 0 is the oldest row of the window -- confirm against the
# CustomClassifier windowing contract.
out[:] = cats[0]
class ComputeExtraRowsTestcase(WithTradingSessions, ZiplineTestCase):
    """
    Tests for ``compute_extra_rows`` on regular and downsampled terms.

    Every test picks three consecutive calendar periods (years, quarters,
    months or weeks) and checks how many extra rows each term requests when
    the unaltered lookback lands on, just after, or just before a period
    boundary.  Regular terms are expected to request exactly
    ``min_extra_rows``; downsampled terms are expected to request enough
    rows to reach the first known session of the period the lookback lands
    in.
    """

    DATA_MIN_DAY = pd.Timestamp('2012-06', tz='UTC')
    DATA_MAX_DAY = pd.Timestamp('2015', tz='UTC')

    TRADING_CALENDAR_STRS = ('NYSE',)

    # Test with different window_lengths to ensure that window length is not
    # used when calculating extra rows for the top-level term.
    factor1 = TestingDataSet.float_col.latest
    factor11 = NDaysAgoFactor(window_length=11)
    factor91 = NDaysAgoFactor(window_length=91)

    filter1 = TestingDataSet.bool_col.latest
    filter11 = NDaysAgoFilter(window_length=11)
    filter91 = NDaysAgoFilter(window_length=91)

    classifier1 = TestingDataSet.categorical_col.latest
    classifier11 = NDaysAgoClassifier(window_length=11)
    classifier91 = NDaysAgoClassifier(window_length=91)

    all_terms = [
        factor1,
        factor11,
        factor91,
        filter1,
        filter11,
        filter91,
        classifier1,
        classifier11,
        classifier91,
    ]

    # One group per term type.  Shared by every parameterized test below so
    # that all frequencies are exercised against the same terms.
    _BASE_TERM_GROUPS = (
        (factor1, factor11, factor91),
        (filter1, filter11, filter91),
        (classifier1, classifier11, classifier91),
    )

    @parameter_space(
        calendar_name=TRADING_CALENDAR_STRS,
        base_terms=_BASE_TERM_GROUPS,
        __fail_fast=True
    )
    def test_yearly(self, base_terms, calendar_name):
        """Check extra-row calculations around year boundaries."""
        all_sessions = self.trading_sessions[calendar_name]
        years = all_sessions.year

        self._check_period_boundaries(
            base_terms,
            frequency='year_start',
            all_sessions=all_sessions,
            # 2012 is only partially known: the first known session is in
            # the middle of the year (see DATA_MIN_DAY).
            periods=(
                all_sessions[years == 2012],
                all_sessions[years == 2013],
                all_sessions[years == 2014],
            ),
            offsets=range(0, 30, 5),
        )

    @parameter_space(
        calendar_name=TRADING_CALENDAR_STRS,
        base_terms=_BASE_TERM_GROUPS,
        __fail_fast=True
    )
    def test_quarterly(self, calendar_name, base_terms):
        """Check extra-row calculations around quarter boundaries."""
        # This region intersects with Q4 2013, Q1 2014, and Q2 2014.
        tmp = self.trading_sessions[calendar_name]
        all_sessions = tmp[tmp.slice_indexer('2013-12-15', '2014-04-30')]

        months = all_sessions.month
        # Q4 2013 is only partially known: the region starts in the middle
        # of December 2013.
        Q4_2013 = all_sessions[months == 12]
        Q1_2014 = all_sessions[(months == 1) | (months == 2) | (months == 3)]
        Q2_2014 = all_sessions[months == 4]

        self._check_period_boundaries(
            base_terms,
            frequency='quarter_start',
            all_sessions=all_sessions,
            periods=(Q4_2013, Q1_2014, Q2_2014),
            offsets=range(0, 15, 5),
        )

    @parameter_space(
        calendar_name=TRADING_CALENDAR_STRS,
        base_terms=_BASE_TERM_GROUPS,
        __fail_fast=True
    )
    def test_monthly(self, calendar_name, base_terms):
        """Check extra-row calculations around month boundaries."""
        # This region intersects with Dec 2013, Jan 2014, and Feb 2014.
        tmp = self.trading_sessions[calendar_name]
        all_sessions = tmp[tmp.slice_indexer('2013-12-15', '2014-02-28')]

        months = all_sessions.month
        # Dec 2013 is only partially known: the region starts in the middle
        # of the month.
        dec2013 = all_sessions[months == 12]
        jan2014 = all_sessions[months == 1]
        feb2014 = all_sessions[months == 2]

        self._check_period_boundaries(
            base_terms,
            frequency='month_start',
            all_sessions=all_sessions,
            periods=(dec2013, jan2014, feb2014),
            offsets=range(0, 10, 2),
        )

    @parameter_space(
        calendar_name=TRADING_CALENDAR_STRS,
        base_terms=_BASE_TERM_GROUPS,
        __fail_fast=True
    )
    def test_weekly(self, calendar_name, base_terms):
        """Check extra-row calculations around week boundaries."""
        #      December 2013
        # Mo Tu We Th Fr Sa Su
        #                    1
        #  2  3  4  5  6  7  8
        #  9 10 11 12 13 14 15
        # 16 17 18 19 20 21 22
        # 23 24 25 26 27 28 29
        # 30 31

        #       January 2014
        # Mo Tu We Th Fr Sa Su
        #        1  2  3  4  5
        #  6  7  8  9 10 11 12
        # 13 14 15 16 17 18 19
        # 20 21 22 23 24 25 26
        # 27 28 29 30 31

        # This region intersects with the last full week of 2013, the week
        # shared by 2013 and 2014, and the first full week of 2014.
        tmp = self.trading_sessions[calendar_name]
        all_sessions = tmp[tmp.slice_indexer('2013-12-27', '2014-01-12')]

        # week0 is only partially known: the region starts on 2013-12-27,
        # which is the Friday of that week.
        week0 = all_sessions[
            all_sessions.slice_indexer('2013-12-27', '2013-12-29')
        ]
        week1 = all_sessions[
            all_sessions.slice_indexer('2013-12-30', '2014-01-05')
        ]
        week2 = all_sessions[
            all_sessions.slice_indexer('2014-01-06', '2014-01-12')
        ]

        self._check_period_boundaries(
            base_terms,
            frequency='week_start',
            all_sessions=all_sessions,
            periods=(week0, week1, week2),
            offsets=range(3),
        )

    def _check_period_boundaries(self,
                                 base_terms,
                                 frequency,
                                 all_sessions,
                                 periods,
                                 offsets):
        """
        Run the four period-boundary scenarios shared by every frequency.

        Parameters
        ----------
        base_terms : tuple
            Non-downsampled terms.  Each is also downsampled to
            ``frequency`` to produce the downsampled terms under test.
        frequency : str
            Frequency name passed to ``downsample``.
        all_sessions : pd.DatetimeIndex
            All sessions known to the computation.  The last one is used as
            the end session of every check.
        periods : tuple
            ``(oldest, previous, current)``: the sessions of three
            consecutive periods of ``frequency``.  ``oldest`` may be
            truncated at the first known session.
        offsets : iterable of int
            Distances (in sessions) from the start of a period at which to
            place the start session.
        """
        oldest, previous, current = periods
        # Materialize so that the offsets can be walked once per scenario.
        offsets = tuple(offsets)
        end_session = all_sessions[-1]

        downsampled_terms = tuple(
            term.downsample(frequency) for term in base_terms
        )
        all_terms = base_terms + downsampled_terms

        def check(terms, start_session, min_extra_rows, expected_extra_rows):
            self.check_extra_row_calculations(
                terms,
                all_sessions,
                start_session,
                end_session,
                min_extra_rows=min_extra_rows,
                expected_extra_rows=expected_extra_rows,
            )

        # Simulate requesting computation where the unaltered lookback would
        # land exactly on the first session of the current period.  We
        # shouldn't request any additional rows for the regular terms or the
        # downsampled terms.
        for i in offsets:
            check(all_terms, current[i], i, i)

        # Simulate requesting computation where the unaltered lookback would
        # land on the second session of the current period.  The downsampled
        # terms should request one more extra row to push us back to the
        # first session of the current period.
        for i in offsets:
            start_session = current[i + 1]
            check(downsampled_terms, start_session, i, i + 1)
            check(base_terms, start_session, i, i)

        # Simulate requesting computation where the unaltered lookback would
        # land on the last session of the previous period.  The downsampled
        # terms should request enough extra rows to push us back to the
        # start of the previous period.
        for i in offsets:
            start_session = current[i]
            check(downsampled_terms, start_session, i + 1, i + len(previous))
            check(base_terms, start_session, i + 1, i + 1)

        # Simulate requesting computation where the unaltered lookback would
        # land on the last session of the oldest period.  The downsampled
        # terms should request enough extra rows to push us back to the
        # first known session of the oldest period.
        for i in offsets:
            start_session = previous[i]
            check(downsampled_terms, start_session, i + 1, i + len(oldest))
            check(base_terms, start_session, i + 1, i + 1)

    def check_extra_row_calculations(self,
                                     terms,
                                     all_sessions,
                                     start_session,
                                     end_session,
                                     min_extra_rows,
                                     expected_extra_rows):
        """
        Check that each term in ``terms`` computes an expected number of extra
        rows for the given parameters.
        """
        for term in terms:
            result = term.compute_extra_rows(
                all_sessions,
                start_session,
                end_session,
                min_extra_rows,
            )
            self.assertEqual(
                result,
                expected_extra_rows,
                "Expected {} extra_rows from {}, but got {}.".format(
                    expected_extra_rows,
                    term,
                    result,
                )
            )
class DownsampledPipelineTestCase(WithSeededRandomPipelineEngine,
                                  ZiplineTestCase):
    """
    End-to-end tests comparing downsampled pipeline output against
    expectations derived from the term's daily output.
    """

    # Extend into the last few days of 2013 to test year/quarter boundaries.
    START_DATE = pd.Timestamp('2013-12-15', tz='UTC')
    # Extend into the first few days of 2015 to test year/quarter boundaries.
    END_DATE = pd.Timestamp('2015-01-06', tz='UTC')

    ASSET_FINDER_EQUITY_SIDS = tuple(range(10))

    @staticmethod
    def _five_day_sma():
        """Build the 5-day moving average of ``float_col`` used below."""
        return SimpleMovingAverage(
            inputs=[TestingDataSet.float_col],
            window_length=5,
        )

    def check_downsampled_term(self, term):
        """
        Run ``term`` downsampled to every supported frequency and compare
        each result with the daily values of ``term`` resampled by pandas.
        """
        #       June 2014
        # Mo Tu We Th Fr Sa Su
        #                    1
        #  2  3  4  5  6  7  8
        #  9 10 11 12 13 14 15
        # 16 17 18 19 20 21 22
        # 23 24 25 26 27 28 29
        # 30
        sessions = self.nyse_sessions
        window = sessions.slice_indexer('2014-06-05', '2015-01-06')
        compute_dates = sessions[window]
        start_date, end_date = compute_dates[[0, -1]]

        downsampled_columns = {
            'year': term.downsample(frequency='year_start'),
            'quarter': term.downsample(frequency='quarter_start'),
            'month': term.downsample(frequency='month_start'),
            'week': term.downsample(frequency='week_start'),
        }
        pipe = Pipeline(downsampled_columns)

        # Daily values of the raw term, from the start of 2014 through the
        # end of the target period.
        daily_pipeline_output = self.run_pipeline(
            Pipeline({'term': term}),
            start_date=pd.Timestamp('2014-01-02', tz='UTC'),
            end_date=pd.Timestamp('2015-01-06', tz='UTC'),
        )
        raw_term_results = daily_pipeline_output['term'].unstack()

        # For each frequency, the expected value on any date is the first
        # daily value of the period containing it, carried forward.
        groupers = {
            'year': pd.TimeGrouper('AS'),
            'quarter': pd.TimeGrouper('QS'),
            'month': pd.TimeGrouper('MS'),
            'week': pd.TimeGrouper('W', label='left'),
        }
        expected_results = {
            name: (raw_term_results
                   .groupby(grouper)
                   .first()
                   .reindex(compute_dates, method='ffill'))
            for name, grouper in groupers.items()
        }

        results = self.run_pipeline(pipe, start_date, end_date)

        for name, expected in expected_results.items():
            assert_frame_equal(results[name].unstack(), expected)

    def test_downsample_windowed_factor(self):
        self.check_downsampled_term(self._five_day_sma())

    def test_downsample_non_windowed_factor(self):
        sma = self._five_day_sma()
        averaged = (sma + sma) / 2
        self.check_downsampled_term(averaged.rank())

    def test_downsample_windowed_filter(self):
        top_four = self._five_day_sma().top(4)
        self.check_downsampled_term(All(inputs=[top_four], window_length=5))

    def test_downsample_nonwindowed_filter(self):
        self.check_downsampled_term(self._five_day_sma() > 5)

    def test_downsample_windowed_classifier(self):

        class IntSumClassifier(CustomClassifier):
            """Labels each asset with its windowed column sum, modulo 4."""
            inputs = [TestingDataSet.float_col]
            window_length = 8
            dtype = int64_dtype
            missing_value = -1

            def compute(self, today, assets, out, floats):
                out[:] = floats.sum(axis=0).astype(int) % 4

        self.check_downsampled_term(IntSumClassifier())

    def test_downsample_nonwindowed_classifier(self):
        self.check_downsampled_term(self._five_day_sma().quantiles(5))

    def test_errors_on_bad_downsample_frequency(self):
        factor = NDaysAgoFactor(window_length=3)
        with self.assertRaises(ValueError) as raised:
            factor.downsample('bad')

        template = (
            "{}() expected a value in "
            "('month_start', 'quarter_start', 'week_start', 'year_start') "
            "for argument 'frequency', but got 'bad' instead."
        )
        expected = template.format(_qualified_name(factor.downsample))
        self.assertEqual(str(raised.exception), expected)
+50 -1
View File
@@ -40,6 +40,7 @@ from six import iteritems, itervalues
from toolz import merge
from zipline.assets.synthetic import make_rotating_equity_info
from zipline.errors import NoFurtherDataError
from zipline.lib.adjustment import MULTIPLY
from zipline.lib.labelarray import LabelArray
from zipline.pipeline import CustomFactor, Pipeline
@@ -65,6 +66,7 @@ from zipline.pipeline.loaders.synthetic import (
expected_bar_values_2d,
)
from zipline.pipeline.sentinels import NotSpecified
from zipline.pipeline.term import InputDates
from zipline.testing import (
AssetID,
AssetIDPlusDay,
@@ -81,7 +83,7 @@ from zipline.testing.fixtures import (
ZiplineTestCase,
)
from zipline.utils.memoize import lazyval
from zipline.utils.numpy_utils import bool_dtype
from zipline.utils.numpy_utils import bool_dtype, datetime64ns_dtype
class RollingSumDifference(CustomFactor):
@@ -206,6 +208,53 @@ class ConstantInputTestCase(WithTradingEnvironment, ZiplineTestCase):
with self.assertRaisesRegexp(ValueError, msg):
engine.run_pipeline(p, self.dates[2], self.dates[1])
def test_fail_usefully_on_insufficient_data(self):
    """
    A pipeline whose start date leaves too little history for a term's
    window must raise NoFurtherDataError.
    """

    class SomeFactor(CustomFactor):
        inputs = [USEquityPricing.close]
        window_length = 10

        def compute(self, today, assets, out, closes):
            pass

    loader = self.loader

    def get_loader(column):
        return loader

    engine = SimplePipelineEngine(get_loader, self.dates, self.asset_finder)
    pipeline = Pipeline(columns={'t': SomeFactor()})

    # self.dates[9] is the earliest date we should be able to compute: it
    # is the first date with a full 10-row window ending on it.
    earliest_computable = self.dates[9]
    engine.run_pipeline(pipeline, earliest_computable, earliest_computable)

    # self.dates[8] has only 8 dates before it, which is not enough to fill
    # a window of length 10, so computing it must fail.
    too_early = self.dates[8]
    with self.assertRaises(NoFurtherDataError):
        engine.run_pipeline(pipeline, too_early, too_early)
def test_input_dates_provided_by_default(self):
    """
    A term may list ``InputDates()`` among its inputs without the caller
    supplying it explicitly to the engine.
    """

    class TestFactor(CustomFactor):
        inputs = [InputDates(), USEquityPricing.close]
        window_length = 10
        dtype = datetime64ns_dtype

        def compute(self, today, assets, out, dates, closes):
            # First and last entries of column 0 of the dates window.
            endpoints = dates[[0, -1], 0]
            first, last = endpoints
            assert last == today.asm8
            assert len(dates) == len(closes) == self.window_length
            # Output the first date of the window for every asset.
            out[:] = first

    loader = self.loader

    def get_loader(column):
        return loader

    engine = SimplePipelineEngine(get_loader, self.dates, self.asset_finder)
    pipeline = Pipeline(columns={'t': TestFactor()})

    results = engine.run_pipeline(pipeline, self.dates[9], self.dates[10])

    # All results are the same, so just grab one column.
    first_asset_values = results.unstack().iloc[:, 0].values
    check_arrays(first_asset_values, self.dates[:2].values)
def test_same_day_pipeline(self):
loader = self.loader
engine = SimplePipelineEngine(
+50 -75
View File
@@ -26,7 +26,7 @@ from zipline.errors import UnknownRankMethod
from zipline.lib.labelarray import LabelArray
from zipline.lib.rank import masked_rankdata_2d
from zipline.lib.normalize import naive_grouped_rowwise_apply as grouped_apply
from zipline.pipeline import Classifier, Factor, Filter, TermGraph
from zipline.pipeline import Classifier, Factor, Filter
from zipline.pipeline.factors import (
Returns,
RSI,
@@ -37,7 +37,6 @@ from zipline.testing import (
parameter_space,
permute_rows,
)
from zipline.utils.functional import dzip_exact
from zipline.utils.numpy_utils import (
categorical_dtype,
datetime64ns_dtype,
@@ -123,20 +122,18 @@ class FactorTestCase(BasePipelineTestCase):
data = arange(25).reshape(5, 5)
data[eye(5, dtype=bool)] = custom_missing_value
graph = TermGraph(
self.check_terms(
{
'isnull': factor.isnull(),
'notnull': factor.notnull(),
}
)
results = self.run_graph(
graph,
},
{
'isnull': eye(5, dtype=bool),
'notnull': ~eye(5, dtype=bool),
},
initial_workspace={factor: data},
mask=self.build_mask(ones((5, 5))),
)
check_arrays(results['isnull'], eye(5, dtype=bool))
check_arrays(results['notnull'], ~eye(5, dtype=bool))
def test_isnull_datetime_dtype(self):
class DatetimeFactor(Factor):
@@ -149,20 +146,18 @@ class FactorTestCase(BasePipelineTestCase):
data = arange(25).reshape(5, 5).astype('datetime64[ns]')
data[eye(5, dtype=bool)] = NaTns
graph = TermGraph(
self.check_terms(
{
'isnull': factor.isnull(),
'notnull': factor.notnull(),
}
)
results = self.run_graph(
graph,
},
{
'isnull': eye(5, dtype=bool),
'notnull': ~eye(5, dtype=bool),
},
initial_workspace={factor: data},
mask=self.build_mask(ones((5, 5))),
)
check_arrays(results['isnull'], eye(5, dtype=bool))
check_arrays(results['notnull'], ~eye(5, dtype=bool))
@for_each_factor_dtype
def test_rank_ascending(self, name, factor_dtype):
@@ -206,14 +201,12 @@ class FactorTestCase(BasePipelineTestCase):
}
def check(terms):
graph = TermGraph(terms)
results = self.run_graph(
graph,
self.check_terms(
terms,
expected={name: expected_ranks[name] for name in terms},
initial_workspace={f: data},
mask=self.build_mask(ones((5, 5))),
)
for method in terms:
check_arrays(results[method], expected_ranks[method])
check({meth: f.rank(method=meth) for meth in expected_ranks})
check({
@@ -265,14 +258,12 @@ class FactorTestCase(BasePipelineTestCase):
}
def check(terms):
graph = TermGraph(terms)
results = self.run_graph(
graph,
self.check_terms(
terms,
expected={name: expected_ranks[name] for name in terms},
initial_workspace={f: data},
mask=self.build_mask(ones((5, 5))),
)
for method in terms:
check_arrays(results[method], expected_ranks[method])
check({
meth: f.rank(method=meth, ascending=False)
@@ -294,14 +285,12 @@ class FactorTestCase(BasePipelineTestCase):
mask_data = ~eye(5, dtype=bool)
initial_workspace = {f: data, Mask(): mask_data}
graph = TermGraph(
{
"ascending_nomask": f.rank(ascending=True),
"ascending_mask": f.rank(ascending=True, mask=Mask()),
"descending_nomask": f.rank(ascending=False),
"descending_mask": f.rank(ascending=False, mask=Mask()),
}
)
terms = {
"ascending_nomask": f.rank(ascending=True),
"ascending_mask": f.rank(ascending=True, mask=Mask()),
"descending_nomask": f.rank(ascending=False),
"descending_mask": f.rank(ascending=False, mask=Mask()),
}
expected = {
"ascending_nomask": array([[1., 3., 4., 5., 2.],
@@ -328,13 +317,12 @@ class FactorTestCase(BasePipelineTestCase):
[4., 3., 2., 1., nan]]),
}
results = self.run_graph(
graph,
self.check_terms(
terms,
expected,
initial_workspace,
mask=self.build_mask(ones((5, 5))),
)
for method in results:
check_arrays(expected[method], results[method])
@for_each_factor_dtype
def test_grouped_rank_ascending(self, name, factor_dtype=float64_dtype):
@@ -363,7 +351,7 @@ class FactorTestCase(BasePipelineTestCase):
missing_value=None,
)
expected_grouped_ranks = {
expected_ranks = {
'ordinal': array(
[[1., 1., 3., 2., 2.],
[1., 2., 3., 1., 2.],
@@ -402,9 +390,9 @@ class FactorTestCase(BasePipelineTestCase):
}
def check(terms):
graph = TermGraph(terms)
results = self.run_graph(
graph,
self.check_terms(
terms,
expected={name: expected_ranks[name] for name in terms},
initial_workspace={
f: data,
c: classifier_data,
@@ -413,25 +401,22 @@ class FactorTestCase(BasePipelineTestCase):
mask=self.build_mask(ones((5, 5))),
)
for method in terms:
check_arrays(results[method], expected_grouped_ranks[method])
# Not specifying the value of ascending param should default to True
check({
meth: f.rank(method=meth, groupby=c)
for meth in expected_grouped_ranks
for meth in expected_ranks
})
check({
meth: f.rank(method=meth, groupby=str_c)
for meth in expected_grouped_ranks
for meth in expected_ranks
})
check({
meth: f.rank(method=meth, groupby=c, ascending=True)
for meth in expected_grouped_ranks
for meth in expected_ranks
})
check({
meth: f.rank(method=meth, groupby=str_c, ascending=True)
for meth in expected_grouped_ranks
for meth in expected_ranks
})
# Not passing a method should default to ordinal
@@ -468,7 +453,7 @@ class FactorTestCase(BasePipelineTestCase):
missing_value=None,
)
expected_grouped_ranks = {
expected_ranks = {
'ordinal': array(
[[2., 2., 1., 1., 3.],
[2., 1., 1., 2., 3.],
@@ -507,9 +492,9 @@ class FactorTestCase(BasePipelineTestCase):
}
def check(terms):
graph = TermGraph(terms)
results = self.run_graph(
graph,
self.check_terms(
terms,
expected={name: expected_ranks[name] for name in terms},
initial_workspace={
f: data,
c: classifier_data,
@@ -518,16 +503,13 @@ class FactorTestCase(BasePipelineTestCase):
mask=self.build_mask(ones((5, 5))),
)
for method in terms:
check_arrays(results[method], expected_grouped_ranks[method])
check({
meth: f.rank(method=meth, groupby=c, ascending=False)
for meth in expected_grouped_ranks
for meth in expected_ranks
})
check({
meth: f.rank(method=meth, groupby=str_c, ascending=False)
for meth in expected_grouped_ranks
for meth in expected_ranks
})
# Not passing a method should default to ordinal
@@ -707,9 +689,9 @@ class FactorTestCase(BasePipelineTestCase):
expected['grouped_str'] = expected['grouped']
expected['grouped_masked_str'] = expected['grouped_masked']
graph = TermGraph(terms)
results = self.run_graph(
graph,
self.check_terms(
terms,
expected,
initial_workspace={
f: factor_data,
c: classifier_data,
@@ -717,20 +699,13 @@ class FactorTestCase(BasePipelineTestCase):
m: filter_data,
},
mask=self.build_mask(self.ones_mask(shape=factor_data.shape)),
# The hand-computed values aren't very precise (in particular,
# we truncate repeating decimals at 3 places) This is just
# asserting that the example isn't misleading by being totally
# wrong.
check=partial(check_allclose, atol=0.001),
)
for key, (res, exp) in dzip_exact(results, expected).items():
check_allclose(
res,
exp,
# The hand-computed values aren't very precise (in particular,
# we truncate repeating decimals at 3 places) This is just
# asserting that the example isn't misleading by being totally
# wrong.
atol=0.001,
err_msg="Mismatch for %r" % key
)
@parameter_space(
seed_value=range(1, 2),
normalizer_name_and_func=[
+156 -179
View File
@@ -12,7 +12,6 @@ from numpy import (
array,
eye,
float64,
full_like,
full,
inf,
isfinite,
@@ -27,11 +26,11 @@ from numpy import (
from numpy.random import randn, seed as random_seed
from zipline.errors import BadPercentileBounds
from zipline.pipeline import Filter, Factor, TermGraph
from zipline.pipeline import Filter, Factor
from zipline.pipeline.classifiers import Classifier
from zipline.pipeline.factors import CustomFactor
from zipline.pipeline.filters import All, Any, AtLeastN
from zipline.testing import check_arrays, parameter_space, permute_rows
from zipline.testing import parameter_space, permute_rows
from zipline.utils.numpy_utils import float64_dtype, int64_dtype
from .base import BasePipelineTestCase, with_default_shape
@@ -127,7 +126,6 @@ class FilterTestCase(BasePipelineTestCase):
nan_data[:, 0] = nan
mask = Mask()
workspace = {self.f: data, mask: mask_data}
methods = ['top', 'bottom']
counts = 2, 3, 10
@@ -136,18 +134,6 @@ class FilterTestCase(BasePipelineTestCase):
def termname(method, count, masked):
return '_'.join([method, str(count), 'mask' if masked else ''])
# Add a term for each permutation of top/bottom, count, and
# mask/no_mask.
terms = {}
for method, count, masked in term_combos:
kwargs = {'N': count}
if masked:
kwargs['mask'] = mask
term = getattr(self.f, method)(**kwargs)
terms[termname(method, count, masked)] = term
results = self.run_graph(TermGraph(terms), initial_workspace=workspace)
def expected_result(method, count, masked):
# Ranking with a mask is equivalent to ranking with nans applied on
# the masked values.
@@ -158,72 +144,55 @@ class FilterTestCase(BasePipelineTestCase):
elif method == 'bottom':
return rowwise_rank(to_rank) < count
# Add a term for each permutation of top/bottom, count, and
# mask/no_mask.
terms = {}
expected = {}
for method, count, masked in term_combos:
result = results[termname(method, count, masked)]
kwargs = {'N': count}
if masked:
kwargs['mask'] = mask
term = getattr(self.f, method)(**kwargs)
name = termname(method, count, masked)
terms[name] = term
expected[name] = expected_result(method, count, masked)
# Check that `min(c, num_assets)` assets passed each day.
passed_per_day = result.sum(axis=1)
check_arrays(
passed_per_day,
full_like(passed_per_day, min(count, data.shape[1])),
)
expected = expected_result(method, count, masked)
check_arrays(result, expected)
def test_bottom(self):
counts = 2, 3, 10
data = self.randn_data(seed=5) # Arbitrary seed choice.
results = self.run_graph(
TermGraph(
{'bottom_' + str(c): self.f.bottom(c) for c in counts}
),
initial_workspace={self.f: data},
self.check_terms(
terms,
expected,
initial_workspace={self.f: data, mask: mask_data},
mask=self.build_mask(self.ones_mask()),
)
for c in counts:
result = results['bottom_' + str(c)]
# Check that `min(c, num_assets)` assets passed each day.
passed_per_day = result.sum(axis=1)
check_arrays(
passed_per_day,
full_like(passed_per_day, min(c, data.shape[1])),
)
# Check that the bottom `c` assets passed.
expected = rowwise_rank(data) < c
check_arrays(result, expected)
def test_percentile_between(self):
quintiles = range(5)
filter_names = ['pct_' + str(q) for q in quintiles]
iter_quintiles = zip(filter_names, quintiles)
graph = TermGraph(
{
name: self.f.percentile_between(q * 20.0, (q + 1) * 20.0)
for name, q in zip(filter_names, quintiles)
}
)
iter_quintiles = list(zip(filter_names, quintiles))
terms = {
name: self.f.percentile_between(q * 20.0, (q + 1) * 20.0)
for name, q in iter_quintiles
}
# Test with 5 columns and no NaNs.
eye5 = eye(5, dtype=float64)
results = self.run_graph(
graph,
initial_workspace={self.f: eye5},
mask=self.build_mask(ones((5, 5))),
)
expected = {}
for name, quintile in iter_quintiles:
result = results[name]
if quintile < 4:
# There are four 0s and one 1 in each row, so the first 4
# quintiles should be all the locations with zeros in the input
# array.
check_arrays(result, ~eye5.astype(bool))
expected[name] = ~eye5.astype(bool)
else:
# The top quintile should match the sole 1 in each row.
check_arrays(result, eye5.astype(bool))
expected[name] = eye5.astype(bool)
self.check_terms(
terms=terms,
expected=expected,
initial_workspace={self.f: eye5},
mask=self.build_mask(ones((5, 5))),
)
# Test with 6 columns, no NaNs, and one masked entry per day.
eye6 = eye(6, dtype=float64)
@@ -233,41 +202,44 @@ class FilterTestCase(BasePipelineTestCase):
[1, 1, 0, 1, 1, 1],
[1, 1, 1, 0, 1, 1],
[1, 1, 1, 1, 0, 1]], dtype=bool)
results = self.run_graph(
graph,
initial_workspace={self.f: eye6},
mask=self.build_mask(mask)
)
expected = {}
for name, quintile in iter_quintiles:
result = results[name]
if quintile < 4:
# Should keep all values that were 0 in the base data and were
# 1 in the mask.
check_arrays(result, mask & (~eye6.astype(bool))),
expected[name] = mask & ~eye6.astype(bool)
else:
# Should keep all the 1s in the base data.
check_arrays(result, eye6.astype(bool))
# The top quintile should match the sole 1 in each row.
expected[name] = eye6.astype(bool)
self.check_terms(
terms=terms,
expected=expected,
initial_workspace={self.f: eye6},
mask=self.build_mask(mask),
)
# Test with 6 columns, no mask, and one NaN per day. Should have the
# same outcome as if we had masked the NaNs.
# In particular, the NaNs should never pass any filters.
eye6_withnans = eye6.copy()
putmask(eye6_withnans, ~mask, nan)
results = self.run_graph(
graph,
initial_workspace={self.f: eye6},
mask=self.build_mask(mask)
)
expected = {}
for name, quintile in iter_quintiles:
result = results[name]
if quintile < 4:
# Should keep all values that were 0 in the base data and were
# 1 in the mask.
check_arrays(result, mask & (~eye6.astype(bool))),
expected[name] = mask & (~eye6.astype(bool))
else:
# Should keep all the 1s in the base data.
check_arrays(result, eye6.astype(bool))
expected[name] = eye6.astype(bool)
self.check_terms(
terms,
expected,
initial_workspace={self.f: eye6},
mask=self.build_mask(mask),
)
def test_percentile_nasty_partitions(self):
# Test percentile with nasty partitions: divide up 5 assets into
@@ -281,27 +253,26 @@ class FilterTestCase(BasePipelineTestCase):
quartiles = range(4)
filter_names = ['pct_' + str(q) for q in quartiles]
graph = TermGraph(
{
name: self.f.percentile_between(q * 25.0, (q + 1) * 25.0)
for name, q in zip(filter_names, quartiles)
}
)
results = self.run_graph(
graph,
initial_workspace={self.f: data},
mask=self.build_mask(ones((5, 5))),
)
terms = {
name: self.f.percentile_between(q * 25.0, (q + 1) * 25.0)
for name, q in zip(filter_names, quartiles)
}
expected = {}
for name, quartile in zip(filter_names, quartiles):
result = results[name]
lower = quartile * 25.0
upper = (quartile + 1) * 25.0
expected = and_(
expected[name] = and_(
nanpercentile(data, lower, axis=1, keepdims=True) <= data,
data <= nanpercentile(data, upper, axis=1, keepdims=True),
)
check_arrays(result, expected)
self.check_terms(
terms,
expected,
initial_workspace={self.f: data},
mask=self.build_mask(ones((5, 5))),
)
def test_percentile_after_mask(self):
f_input = eye(5)
@@ -312,77 +283,79 @@ class FilterTestCase(BasePipelineTestCase):
without_mask = self.g.percentile_between(80, 100)
with_mask = self.g.percentile_between(80, 100, mask=custom_mask)
graph = TermGraph(
{
'custom_mask': custom_mask,
'without': without_mask,
'with': with_mask,
}
)
terms = {
'mask': custom_mask,
'without_mask': without_mask,
'with_mask': with_mask,
}
expected = {
# Mask that accepts everything except the diagonal.
'mask': ~eye(5, dtype=bool),
# Second should pass the largest value each day. Each row is
# strictly increasing, so we always select the last value.
'without_mask': array(
[[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1]],
dtype=bool,
),
# With a mask, we should remove the diagonal as an option before
# computing percentiles. On the last day, we should get the
# second-largest value, rather than the largest.
'with_mask': array(
[[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 1, 0]], # Different from without_mask!
dtype=bool,
),
}
results = self.run_graph(
graph,
self.check_terms(
terms,
expected,
initial_workspace={self.f: f_input, self.g: g_input},
mask=initial_mask,
)
# First should pass everything but the diagonal.
check_arrays(results['custom_mask'], ~eye(5, dtype=bool))
# Second should pass the largest value each day. Each row is strictly
# increasing, so we always select the last value.
expected_without = array(
[[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1]],
dtype=bool,
)
check_arrays(results['without'], expected_without)
# When sequencing, we should remove the diagonal as an option before
# computing percentiles. On the last day, we should get the
# second-largest value, rather than the largest.
expected_with = array(
[[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 1, 0]], # Different from previous!
dtype=bool,
)
check_arrays(results['with'], expected_with)
def test_isnan(self):
data = self.randn_data(seed=10)
diag = eye(*data.shape, dtype=bool)
data[diag] = nan
results = self.run_graph(
TermGraph({
self.check_terms(
terms={
'isnan': self.f.isnan(),
'isnull': self.f.isnull(),
}),
},
expected={
'isnan': diag,
'isnull': diag,
},
initial_workspace={self.f: data},
mask=self.build_mask(self.ones_mask()),
)
check_arrays(results['isnan'], diag)
check_arrays(results['isnull'], diag)
def test_notnan(self):
data = self.randn_data(seed=10)
diag = eye(*data.shape, dtype=bool)
data[diag] = nan
results = self.run_graph(
TermGraph({
self.check_terms(
terms={
'notnan': self.f.notnan(),
'notnull': self.f.notnull(),
}),
},
expected={
'notnan': ~diag,
'notnull': ~diag,
},
initial_workspace={self.f: data},
mask=self.build_mask(self.ones_mask()),
)
check_arrays(results['notnan'], ~diag)
check_arrays(results['notnull'], ~diag)
def test_isfinite(self):
data = self.randn_data(seed=10)
@@ -390,11 +363,12 @@ class FilterTestCase(BasePipelineTestCase):
data[:, 2] = inf
data[:, 4] = -inf
results = self.run_graph(
TermGraph({'isfinite': self.f.isfinite()}),
self.check_terms(
terms={'isfinite': self.f.isfinite()},
expected={'isfinite': isfinite(data)},
initial_workspace={self.f: data},
mask=self.build_mask(self.ones_mask()),
)
check_arrays(results['isfinite'], isfinite(data))
def test_all(self):
@@ -427,18 +401,19 @@ class FilterTestCase(BasePipelineTestCase):
inputs = ()
window_length = 0
results = self.run_graph(
TermGraph({
self.check_terms(
terms={
'3': All(inputs=[Input()], window_length=3),
'4': All(inputs=[Input()], window_length=4),
}),
},
expected={
'3': expected_3,
'4': expected_4,
},
initial_workspace={Input(): data},
mask=self.build_mask(ones(shape=data.shape)),
)
check_arrays(results['3'], expected_3)
check_arrays(results['4'], expected_4)
def test_any(self):
# FUN FACT: The inputs and outputs here are exactly the negation of
@@ -486,18 +461,19 @@ class FilterTestCase(BasePipelineTestCase):
inputs = ()
window_length = 0
results = self.run_graph(
TermGraph({
self.check_terms(
terms={
'3': Any(inputs=[Input()], window_length=3),
'4': Any(inputs=[Input()], window_length=4),
}),
},
expected={
'3': expected_3,
'4': expected_4,
},
initial_workspace={Input(): data},
mask=self.build_mask(ones(shape=data.shape)),
)
check_arrays(results['3'], expected_3)
check_arrays(results['4'], expected_4)
def test_at_least_N(self):
# With a window_length of K, AtLeastN should return 1
@@ -553,26 +529,27 @@ class FilterTestCase(BasePipelineTestCase):
window_length=4,
N=4)
results = self.run_graph(
TermGraph({
self.check_terms(
terms={
'AllButOne': all_but_one,
'AllButTwo': all_but_two,
'AnyEquiv': any_equiv,
'AllEquiv': all_equiv,
'Any': Any(inputs=[Input()], window_length=4),
'All': All(inputs=[Input()], window_length=4)
}),
},
expected={
'Any': expected_1,
'AnyEquiv': expected_1,
'AllButTwo': expected_2,
'AllButOne': expected_3,
'All': expected_4,
'AllEquiv': expected_4,
},
initial_workspace={Input(): data},
mask=self.build_mask(ones(shape=data.shape)),
)
check_arrays(results['Any'], expected_1)
check_arrays(results['AnyEquiv'], expected_1)
check_arrays(results['AllButTwo'], expected_2)
check_arrays(results['AllButOne'], expected_3)
check_arrays(results['All'], expected_4)
check_arrays(results['AllEquiv'], expected_4)
@parameter_space(factor_len=[2, 3, 4])
def test_window_safe(self, factor_len):
# all true data set of (days, securities)
@@ -591,19 +568,19 @@ class FilterTestCase(BasePipelineTestCase):
# sum for each column
out[:] = np_sum(filter_, axis=0)
results = self.run_graph(
TermGraph({'windowsafe': TestFactor()}),
initial_workspace={InputFilter(): data},
)
# number of days in default_shape
n = self.default_shape[0]
# shape of output array
output_shape = ((n - factor_len + 1), self.default_shape[1])
check_arrays(
results['windowsafe'],
full(output_shape, factor_len, dtype=float64)
full(output_shape, factor_len, dtype=float64)
self.check_terms(
terms={
'windowsafe': TestFactor(),
},
expected={
'windowsafe': full(output_shape, factor_len, dtype=float64),
},
initial_workspace={InputFilter(): data},
mask=self.build_mask(self.ones_mask()),
)
@parameter_space(
+58
View File
@@ -1,10 +1,14 @@
"""
Tests for zipline.pipeline.Pipeline
"""
import inspect
from unittest import TestCase
from mock import patch
from zipline.pipeline import Factor, Filter, Pipeline
from zipline.pipeline.data import USEquityPricing
from zipline.pipeline.graph import display_graph
from zipline.utils.numpy_utils import float64_dtype
@@ -137,3 +141,57 @@ class PipelineTestCase(TestCase):
"expected a value of type bool or int for argument 'overwrite'",
message,
)
def test_show_graph(self):
    factor = SomeFactor()
    pipeline = Pipeline(columns={'f': SomeFactor()})

    # Rendering for real shells out to GraphViz, which is an optional
    # dependency, so a stand-in that simply echoes its arguments is swapped
    # in.  Its parameter names and defaults must mirror display_graph.
    def fake_display_graph(g, format='svg', include_asset_exists=False):
        return (g, format, include_asset_exists)

    self.assertEqual(
        inspect.getargspec(display_graph),
        inspect.getargspec(fake_display_graph),
        msg="Mock signature doesn't match signature for display_graph."
    )

    patcher = patch(
        'zipline.pipeline.graph.display_graph',
        fake_display_graph,
    )

    # Each entry pairs the keyword arguments given to show_graph with the
    # format that should be forwarded to display_graph.  The empty dict
    # exercises the default format.
    accepted_calls = (
        ({}, 'svg'),
        ({'format': 'png'}, 'png'),
        ({'format': 'jpeg'}, 'jpeg'),
    )
    for call_kwargs, expected_format in accepted_calls:
        with patcher:
            graph, seen_format, seen_include_asset_exists = (
                pipeline.show_graph(**call_kwargs)
            )

        self.assertIs(graph.outputs['f'], factor)
        # '' is a sentinel used for screen if it's not supplied.
        self.assertEqual(sorted(graph.outputs.keys()), ['', 'f'])
        self.assertEqual(seen_format, expected_format)
        self.assertEqual(seen_include_asset_exists, False)

    # Any format outside the supported set must be rejected up front.
    error_pattern = (
        r".*\.show_graph\(\) expected a value in "
        r"\('svg', 'png', 'jpeg'\) for argument 'format', "
        r"but got 'fizzbuzz' instead."
    )
    with self.assertRaisesRegexp(ValueError, error_pattern):
        pipeline.show_graph(format='fizzbuzz')
+42 -81
View File
@@ -7,10 +7,7 @@ import pandas as pd
import talib
from zipline.lib.adjusted_array import AdjustedArray
from zipline.pipeline import TermGraph
from zipline.pipeline.data import USEquityPricing
from zipline.pipeline.engine import SimplePipelineEngine
from zipline.pipeline.term import AssetExists
from zipline.pipeline.factors import (
BollingerBands,
Aroon,
@@ -20,61 +17,22 @@ from zipline.pipeline.factors import (
RateOfChangePercentage,
TrueRange,
)
from zipline.testing import ExplodingObject, parameter_space
from zipline.testing.fixtures import WithAssetFinder, ZiplineTestCase
from zipline.testing import parameter_space
from zipline.testing.fixtures import ZiplineTestCase
from zipline.testing.predicates import assert_equal
class WithTechnicalFactor(WithAssetFinder):
"""ZiplineTestCase fixture for testing technical factors.
"""
ASSET_FINDER_EQUITY_SIDS = tuple(range(5))
START_DATE = pd.Timestamp('2014-01-01', tz='utc')
@classmethod
def init_class_fixtures(cls):
super(WithTechnicalFactor, cls).init_class_fixtures()
cls.ndays = ndays = 24
cls.nassets = nassets = len(cls.ASSET_FINDER_EQUITY_SIDS)
cls.dates = dates = pd.date_range(cls.START_DATE, periods=ndays)
cls.assets = pd.Index(cls.asset_finder.sids)
cls.engine = SimplePipelineEngine(
lambda column: ExplodingObject(),
dates,
cls.asset_finder,
)
cls.asset_exists = exists = np.full((ndays, nassets), True, dtype=bool)
cls.asset_exists_masked = masked = exists.copy()
masked[:, -1] = False
def run_graph(self, graph, initial_workspace, mask_sid):
initial_workspace.setdefault(
AssetExists(),
self.asset_exists_masked if mask_sid else self.asset_exists,
)
return self.engine.compute_chunk(
graph,
self.dates,
self.assets,
initial_workspace,
)
from .base import BasePipelineTestCase
class BollingerBandsTestCase(WithTechnicalFactor, ZiplineTestCase):
@classmethod
def init_class_fixtures(cls):
super(BollingerBandsTestCase, cls).init_class_fixtures()
cls._closes = closes = (
np.arange(cls.ndays, dtype=float)[:, np.newaxis] +
np.arange(cls.nassets, dtype=float) * 100
)
cls._closes_masked = masked = closes.copy()
masked[:, -1] = np.nan
class BollingerBandsTestCase(BasePipelineTestCase):
def closes(self, masked):
return self._closes_masked if masked else self._closes
def closes(self, mask_last_sid):
data = self.arange_data(dtype=np.float64)
if mask_last_sid:
data[:, -1] = np.nan
return data
def expected(self, window_length, k, closes):
def expected_bbands(self, window_length, k, closes):
"""Compute the expected data (without adjustments) for the given
window, k, and closes array.
@@ -83,11 +41,14 @@ class BollingerBandsTestCase(WithTechnicalFactor, ZiplineTestCase):
lower_cols = []
middle_cols = []
upper_cols = []
for n in range(self.nassets):
ndates, nassets = closes.shape
for n in range(nassets):
close_col = closes[:, n]
if np.isnan(close_col).all():
# ta-lib doesn't deal well with all nans.
upper, middle, lower = [np.full(self.ndays, np.nan)] * 3
upper, middle, lower = [np.full(ndates, np.nan)] * 3
else:
upper, middle, lower = talib.BBANDS(
close_col,
@@ -112,38 +73,38 @@ class BollingerBandsTestCase(WithTechnicalFactor, ZiplineTestCase):
@parameter_space(
window_length={5, 10, 20},
k={1.5, 2, 2.5},
mask_sid={True, False},
mask_last_sid={True, False},
__fail_fast=True,
)
def test_bollinger_bands(self, window_length, k, mask_sid):
closes = self.closes(mask_sid)
result = self.run_graph(
TermGraph({
'f': BollingerBands(
window_length=window_length,
k=k,
),
}),
def test_bollinger_bands(self, window_length, k, mask_last_sid):
closes = self.closes(mask_last_sid=mask_last_sid)
mask = ~np.isnan(closes)
bbands = BollingerBands(window_length=window_length, k=k)
expected = self.expected_bbands(window_length, k, closes)
self.check_terms(
terms={
'upper': bbands.upper,
'middle': bbands.middle,
'lower': bbands.lower,
},
expected={
'upper': expected[0],
'middle': expected[1],
'lower': expected[2],
},
initial_workspace={
USEquityPricing.close: AdjustedArray(
closes,
np.full_like(closes, True, dtype=bool),
{},
np.nan,
data=closes,
mask=mask,
adjustments={},
missing_value=np.nan,
),
},
mask_sid=mask_sid,
)['f']
expected_upper, expected_middle, expected_lower = self.expected(
window_length,
k,
closes,
mask=self.build_mask(mask),
)
assert_equal(result.upper, expected_upper)
assert_equal(result.middle, expected_middle)
assert_equal(result.lower, expected_lower)
def test_bollinger_bands_output_ordering(self):
bbands = BollingerBands(window_length=5, k=2)
lower, middle, upper = bbands
@@ -185,7 +146,7 @@ class AroonTestCase(ZiplineTestCase):
assert_equal(out, expected_out)
class TestFastStochasticOscillator(WithTechnicalFactor, ZiplineTestCase):
class TestFastStochasticOscillator(ZiplineTestCase):
"""
Test the Fast Stochastic Oscillator
"""
@@ -427,7 +388,7 @@ class TestLinearWeightedMovingAverage(ZiplineTestCase):
assert_equal(out, np.array([30., 31., 32., 33., 34.]))
class TestTrueRange(WithTechnicalFactor, ZiplineTestCase):
class TestTrueRange(ZiplineTestCase):
def test_tr_basic(self):
tr = TrueRange()
+24 -5
View File
@@ -6,6 +6,7 @@ from itertools import product
from unittest import TestCase
from toolz import assoc
import pandas as pd
from zipline.assets import Asset
from zipline.errors import (
@@ -24,7 +25,7 @@ from zipline.pipeline import (
CustomFactor,
Factor,
Filter,
TermGraph,
ExecutionPlan,
)
from zipline.pipeline.data import Column, DataSet
from zipline.pipeline.data.testing import TestingDataSet
@@ -33,6 +34,7 @@ from zipline.pipeline.factors import RecarrayField
from zipline.pipeline.sentinels import NotSpecified
from zipline.pipeline.term import AssetExists, Slice
from zipline.testing import parameter_space
from zipline.testing.fixtures import WithTradingSessions, ZiplineTestCase
from zipline.testing.predicates import (
assert_equal,
assert_raises,
@@ -152,7 +154,14 @@ def to_dict(l):
return dict(zip(map(str, range(len(l))), l))
class DependencyResolutionTestCase(TestCase):
class DependencyResolutionTestCase(WithTradingSessions, ZiplineTestCase):
TRADING_CALENDAR_STRS = ('NYSE',)
START_DATE = pd.Timestamp('2014-01-02', tz='UTC')
END_DATE = pd.Timestamp('2014-12-31', tz='UTC')
execution_plan_start = pd.Timestamp('2014-06-01', tz='UTC')
execution_plan_end = pd.Timestamp('2014-06-30', tz='UTC')
def check_dependency_order(self, ordered_terms):
seen = set()
@@ -163,6 +172,14 @@ class DependencyResolutionTestCase(TestCase):
seen.add(term)
def make_execution_plan(self, terms):
    """
    Build an ``ExecutionPlan`` for ``terms`` from this test case's fixtures.

    The plan is constructed from ``self.nyse_sessions`` together with the
    class-level ``execution_plan_start`` and ``execution_plan_end``
    timestamps, so every test in the case resolves dependencies over the
    same window.
    """
    return ExecutionPlan(
        terms,
        self.nyse_sessions,
        self.execution_plan_start,
        self.execution_plan_end,
    )
def test_single_factor(self):
"""
Test dependency resolution for a single factor.
@@ -182,7 +199,7 @@ class DependencyResolutionTestCase(TestCase):
self.assertEqual(graph.node[SomeDataSet.bar]['extra_rows'], 4)
for foobar in gen_equivalent_factors():
check_output(TermGraph(to_dict([foobar])))
check_output(self.make_execution_plan(to_dict([foobar])))
def test_single_factor_instance_args(self):
"""
@@ -190,7 +207,9 @@ class DependencyResolutionTestCase(TestCase):
the constructor.
"""
bar, buzz = SomeDataSet.bar, SomeDataSet.buzz
graph = TermGraph(to_dict([SomeFactor([bar, buzz], window_length=5)]))
factor = SomeFactor([bar, buzz], window_length=5)
graph = self.make_execution_plan(to_dict([factor]))
resolution_order = list(graph.ordered())
@@ -214,7 +233,7 @@ class DependencyResolutionTestCase(TestCase):
f1 = SomeFactor([SomeDataSet.foo, SomeDataSet.bar])
f2 = SomeOtherFactor([SomeDataSet.bar, SomeDataSet.buzz])
graph = TermGraph(to_dict([f1, f2]))
graph = self.make_execution_plan(to_dict([f1, f2]))
resolution_order = list(graph.ordered())
# bar should only appear once.
+79
View File
@@ -0,0 +1,79 @@
"""
Tests for zipline/utils/pandas_utils.py
"""
import pandas as pd
from zipline.testing import parameter_space, ZiplineTestCase
from zipline.utils.pandas_utils import nearest_unequal_elements
class TestNearestUnequalElements(ZiplineTestCase):
    """
    Tests for ``nearest_unequal_elements``.
    """

    @parameter_space(tz=['UTC', 'US/Eastern'], __fail_fast=True)
    def test_nearest_unequal_elements(self, tz):
        def as_timestamp(label):
            # ``None`` means "no neighbour on this side" and passes through.
            if label is None:
                return None
            return pd.Timestamp(label, tz=tz)

        index = pd.to_datetime(
            ['2014-01-01', '2014-01-05', '2014-01-06', '2014-01-09'],
        ).tz_localize(tz)

        # (query date, expected element before, expected element after)
        cases = (
            ('2013-12-30', None, '2014-01-01'),
            ('2013-12-31', None, '2014-01-01'),
            ('2014-01-01', None, '2014-01-05'),
            ('2014-01-02', '2014-01-01', '2014-01-05'),
            ('2014-01-03', '2014-01-01', '2014-01-05'),
            ('2014-01-04', '2014-01-01', '2014-01-05'),
            ('2014-01-05', '2014-01-01', '2014-01-06'),
            ('2014-01-06', '2014-01-05', '2014-01-09'),
            ('2014-01-07', '2014-01-06', '2014-01-09'),
            ('2014-01-08', '2014-01-06', '2014-01-09'),
            ('2014-01-09', '2014-01-06', None),
            ('2014-01-10', '2014-01-09', None),
            ('2014-01-11', '2014-01-09', None),
        )
        for query, before, after in cases:
            self.assertEqual(
                nearest_unequal_elements(index, as_timestamp(query)),
                (as_timestamp(before), as_timestamp(after)),
            )

    @parameter_space(tz=['UTC', 'US/Eastern'], __fail_fast=True)
    def test_nearest_unequal_elements_short_dts(self, tz):
        def as_timestamp(label):
            if label is None:
                return None
            return pd.Timestamp(label, tz=tz)

        def check(index, cases):
            for query, before, after in cases:
                self.assertEqual(
                    nearest_unequal_elements(index, as_timestamp(query)),
                    (as_timestamp(before), as_timestamp(after)),
                )

        # Length 1.
        check(
            pd.to_datetime(['2014-01-01']).tz_localize(tz),
            (
                ('2013-12-31', None, '2014-01-01'),
                ('2014-01-01', None, None),
                ('2014-01-02', '2014-01-01', None),
            ),
        )

        # Length 0
        check(
            pd.to_datetime([]).tz_localize(tz),
            (
                ('2013-12-31', None, None),
                ('2014-01-01', None, None),
                ('2014-01-02', None, None),
            ),
        )

    def test_nearest_unequal_bad_input(self):
        # (raw dates handed to the function, message it must raise)
        bad_inputs = (
            (['2014', '2014'], 'dts must be unique'),
            (['2014', '2013'], 'dts must be sorted in increasing order'),
        )
        for raw_dates, message in bad_inputs:
            with self.assertRaises(ValueError) as ctx:
                nearest_unequal_elements(
                    pd.to_datetime(raw_dates),
                    pd.Timestamp('2014'),
                )
            self.assertEqual(str(ctx.exception), message)
+5 -1
View File
@@ -262,7 +262,11 @@ class PreprocessTestCase(TestCase):
expected_message = (
"{qualname}() expected a value in {set_!r}"
" for argument 'a', but got 'c' instead."
).format(set_=set_, qualname=qualname(f))
).format(
# We special-case set to show a tuple instead of the set repr.
set_=tuple(sorted(set_)),
qualname=qualname(f),
)
self.assertEqual(e.exception.args[0], expected_message)
def test_expect_dtypes(self):
+2 -1
View File
@@ -59,6 +59,7 @@ from .asset_db_schema import (
)
from zipline.utils.control_flow import invert
from zipline.utils.memoize import lazyval
from zipline.utils.numpy_utils import as_column
from zipline.utils.sqlite_utils import group_into_chunks
log = Logger('assets.py')
@@ -1096,7 +1097,7 @@ class AssetFinder(object):
self._asset_lifetimes = self._compute_asset_lifetimes()
lifetimes = self._asset_lifetimes
raw_dates = dates.asi8[:, None]
raw_dates = as_column(dates.asi8)
if include_start_date:
mask = lifetimes.start <= raw_dates
else:
+24
View File
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from textwrap import dedent
from zipline.utils.memoize import lazyval
@@ -583,6 +584,29 @@ class NoFurtherDataError(ZiplineError):
# that can be usefully templated.
msg = '{msg}'
@classmethod
def from_lookback_window(cls,
                         initial_message,
                         first_date,
                         lookback_start,
                         lookback_length):
    """
    Alternate constructor for errors caused by a lookback window that
    reaches back before the earliest available date.

    Parameters
    ----------
    initial_message
        Text placed on the first line of the message.
    first_date
        Reported as the earliest known date.
    lookback_start
        Reported as the date on which the lookback window started.
    lookback_length
        Reported as the number of extra rows of data that were required.

    Returns
    -------
    An instance of ``cls`` constructed with the rendered text as ``msg``.
    """
    # Dedent the template *before* substituting values so that whitespace
    # inside the substituted values cannot influence the dedent.
    template = dedent(
        """
        {initial_message}
        lookback window started at {lookback_start}
        earliest known date was {first_date}
        {lookback_length} extra rows of data were required
        """
    )
    rendered = template.format(
        initial_message=initial_message,
        first_date=first_date,
        lookback_start=lookback_start,
        lookback_length=lookback_length,
    )
    return cls(msg=rendered)
class UnsupportedDatetimeFormat(ZiplineError):
"""
+2 -1
View File
@@ -6,7 +6,7 @@ from .engine import SimplePipelineEngine
from .factors import Factor, CustomFactor
from .filters import Filter, CustomFilter
from .term import Term
from .graph import TermGraph
from .graph import ExecutionPlan, TermGraph
from .pipeline import Pipeline
from .loaders import USEquityPricingLoader
@@ -56,6 +56,7 @@ __all__ = (
'CustomFilter',
'CustomClassifier',
'engine_from_files',
'ExecutionPlan',
'Factor',
'Filter',
'Pipeline',
@@ -14,6 +14,7 @@ from zipline.pipeline.sentinels import NotSpecified
from zipline.pipeline.term import ComputableTerm
from zipline.utils.compat import unicode
from zipline.utils.input_validation import expect_types
from zipline.utils.memoize import classlazyval
from zipline.utils.numpy_utils import (
categorical_dtype,
int64_dtype,
@@ -23,6 +24,7 @@ from zipline.utils.numpy_utils import (
from ..filters import ArrayPredicate, NotNullFilter, NullFilter, NumExprFilter
from ..mixins import (
CustomTermMixin,
DownsampledMixin,
LatestMixin,
PositiveWindowLengthMixin,
RestrictedDTypeMixin,
@@ -301,6 +303,10 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
raise AssertionError("Expected a LabelArray, got %s." % type(data))
return data.as_categorical()
@classlazyval
def _downsampled_type(self):
    """
    The downsampled counterpart type for ``Classifier``, as produced by
    ``DownsampledMixin.make_downsampled_type``.
    """
    # NOTE(review): presumably ``classlazyval`` evaluates this once and
    # caches the result on the class -- confirm in zipline.utils.memoize.
    return DownsampledMixin.make_downsampled_type(Classifier)
class Everything(Classifier):
"""
-1
View File
@@ -112,7 +112,6 @@ class BoundColumn(LoadableTerm):
The name of this column.
"""
mask = AssetExists()
inputs = ()
window_safe = True
def __new__(cls, dtype, missing_value, dataset, name):
+61
View File
@@ -0,0 +1,61 @@
"""
Helpers for downsampling code.
"""
from operator import attrgetter
from zipline.utils.input_validation import expect_element
from zipline.utils.numpy_utils import changed_locations
from zipline.utils.sharedoc import (
templated_docstring,
PIPELINE_DOWNSAMPLING_FREQUENCY_DOC,
)
_dt_to_period = {
'year_start': attrgetter('year'),
'quarter_start': attrgetter('quarter'),
'month_start': attrgetter('month'),
'week_start': attrgetter('week'),
}
SUPPORTED_DOWNSAMPLE_FREQUENCIES = frozenset(_dt_to_period)
expect_downsample_frequency = expect_element(
frequency=SUPPORTED_DOWNSAMPLE_FREQUENCIES,
)
@expect_downsample_frequency
@templated_docstring(frequency=PIPELINE_DOWNSAMPLING_FREQUENCY_DOC)
def select_sampling_indices(dates, frequency):
    """
    Choose entries from ``dates`` to use for downsampling at ``frequency``.

    Parameters
    ----------
    dates : pd.DatetimeIndex
        Dates from which to select sample choices.
    {frequency}

    Returns
    -------
    indices : np.array[int64]
        An array containing indices of dates on which samples should be taken.

        The resulting index will always include 0 as a sample index, and it
        will include the first date of each subsequent year/quarter/month/week,
        as determined by ``frequency``.

    Notes
    -----
    This function assumes that ``dates`` does not have large gaps.

    In particular, it assumes that the maximum distance between any two
    consecutive entries in ``dates`` is never greater than a year, which we
    rely on because we use ``np.diff(dates.<frequency>)`` to find dates where
    the sampling period has changed.
    """
    # Map each date to its period number (year/quarter/month/week) and sample
    # wherever that number changes; ``include_first=True`` requests that the
    # first position is always part of the result.
    return changed_locations(
        _dt_to_period[frequency](dates),
        include_first=True
    )
+25 -11
View File
@@ -18,10 +18,14 @@ from toolz.curried.operator import getitem
from zipline.lib.adjusted_array import ensure_adjusted_array, ensure_ndarray
from zipline.errors import NoFurtherDataError
from zipline.utils.numpy_utils import repeat_first_axis, repeat_last_axis
from zipline.utils.numpy_utils import (
as_column,
repeat_first_axis,
repeat_last_axis,
)
from zipline.utils.pandas_utils import explode
from .term import AssetExists, LoadableTerm
from .term import AssetExists, InputDates, LoadableTerm
class PipelineEngine(with_metaclass(ABCMeta)):
@@ -98,6 +102,7 @@ class SimplePipelineEngine(object):
'_calendar',
'_finder',
'_root_mask_term',
'_root_mask_dates_term',
'__weakref__',
)
@@ -105,7 +110,9 @@ class SimplePipelineEngine(object):
self._get_loader = get_loader
self._calendar = calendar
self._finder = asset_finder
self._root_mask_term = AssetExists()
self._root_mask_dates_term = InputDates()
def run_pipeline(self, pipeline, start_date, end_date):
"""
@@ -161,7 +168,13 @@ class SimplePipelineEngine(object):
)
screen_name = uuid4().hex
graph = pipeline.to_graph(screen_name, self._root_mask_term)
graph = pipeline.to_execution_plan(
screen_name,
self._root_mask_term,
self._calendar,
start_date,
end_date,
)
extra_rows = graph.extra_rows[self._root_mask_term]
root_mask = self._compute_root_mask(start_date, end_date, extra_rows)
dates, assets, root_mask_values = explode(root_mask)
@@ -170,7 +183,10 @@ class SimplePipelineEngine(object):
graph,
dates,
assets,
initial_workspace={self._root_mask_term: root_mask_values},
initial_workspace={
self._root_mask_term: root_mask_values,
self._root_mask_dates_term: as_column(dates.values)
},
)
return self._to_narrow(
@@ -210,13 +226,11 @@ class SimplePipelineEngine(object):
finder = self._finder
start_idx, end_idx = self._calendar.slice_locs(start_date, end_date)
if start_idx < extra_rows:
raise NoFurtherDataError(
msg="Insufficient data to compute Pipeline mask: "
"start date was %s, "
"earliest known date was %s, "
"and %d extra rows were requested." % (
start_date, calendar[0], extra_rows,
),
raise NoFurtherDataError.from_lookback_window(
initial_message="Insufficient data to compute Pipeline:",
first_date=calendar[0],
lookback_start=start_date,
lookback_length=extra_rows,
)
# Build lifetimes matrix reaching back to `extra_rows` days before
+9 -1
View File
@@ -33,6 +33,7 @@ from zipline.pipeline.filters import (
)
from zipline.pipeline.mixins import (
CustomTermMixin,
DownsampledMixin,
LatestMixin,
PositiveWindowLengthMixin,
RestrictedDTypeMixin,
@@ -43,6 +44,7 @@ from zipline.pipeline.term import ComputableTerm, Term
from zipline.utils.functional import with_doc, with_name
from zipline.utils.input_validation import expect_types
from zipline.utils.math_utils import nanmean, nanstd
from zipline.utils.memoize import classlazyval
from zipline.utils.numpy_utils import (
bool_dtype,
categorical_dtype,
@@ -1071,6 +1073,10 @@ class Factor(RestrictedDTypeMixin, ComputableTerm):
"""
return (-inf < self) & (self < inf)
@classlazyval
def _downsampled_type(self):
    # The type that ``downsample()`` instantiates for factors, built by
    # ``DownsampledMixin.make_downsampled_type`` with ``Factor`` as the
    # companion base class.
    # NOTE(review): under ``classlazyval`` the argument is presumably the
    # class rather than an instance -- confirm against
    # zipline.utils.memoize.
    return DownsampledMixin.make_downsampled_type(Factor)
class NumExprFactor(NumericalExpression, Factor):
"""
@@ -1468,7 +1474,9 @@ class CustomFactor(PositiveWindowLengthMixin, CustomTermMixin, Factor):
class RecarrayField(SingleInputMixin, Factor):
"""
A single field from a multi-output factor.
"""
def __new__(cls, factor, attribute):
return super(RecarrayField, cls).__new__(
cls,
+6
View File
@@ -25,6 +25,7 @@ from zipline.pipeline.expression import (
)
from zipline.pipeline.mixins import (
CustomTermMixin,
DownsampledMixin,
LatestMixin,
PositiveWindowLengthMixin,
RestrictedDTypeMixin,
@@ -32,6 +33,7 @@ from zipline.pipeline.mixins import (
)
from zipline.pipeline.term import ComputableTerm, Term
from zipline.utils.input_validation import expect_types
from zipline.utils.memoize import classlazyval
from zipline.utils.numpy_utils import bool_dtype, repeat_first_axis
@@ -201,6 +203,10 @@ class Filter(RestrictedDTypeMixin, ComputableTerm):
)
return retval
@classlazyval
def _downsampled_type(self):
    # The type that ``downsample()`` instantiates for filters, built by
    # ``DownsampledMixin.make_downsampled_type`` with ``Filter`` as the
    # companion base class.
    # NOTE(review): under ``classlazyval`` the argument is presumably the
    # class rather than an instance -- confirm against
    # zipline.utils.memoize.
    return DownsampledMixin.make_downsampled_type(Filter)
class NumExprFilter(NumericalExpression, Filter):
"""
+163 -79
View File
@@ -5,7 +5,7 @@ from networkx import (
DiGraph,
topological_sort,
)
from six import itervalues, iteritems
from six import iteritems, itervalues
from zipline.utils.memoize import lazyval
from zipline.pipeline.visualize import display_graph
@@ -18,7 +18,112 @@ class CyclicDependency(Exception):
class TermGraph(DiGraph):
"""
Graph represention of Pipeline Term dependencies.
An abstract representation of Pipeline Term dependencies.
This class does not keep any additional metadata about any term relations
other than dependency ordering. As such it is only useful in contexts
where you care exclusively about order properties (for example, when
drawing visualizations of execution order).
Parameters
----------
terms : dict
A dict mapping names to final output terms.
Attributes
----------
outputs
Methods
-------
ordered()
Return a topologically-sorted iterator over the terms in self.
See Also
--------
ExecutionPlan
"""
def __init__(self, terms):
super(TermGraph, self).__init__()
self._frozen = False
parents = set()
for term in itervalues(terms):
self._add_to_graph(term, parents)
# No parents should be left between top-level terms.
assert not parents
self._outputs = terms
self._ordered = topological_sort(self)
# Mark that no more terms should be added to the graph.
self._frozen = True
def _add_to_graph(self, term, parents):
    """
    Recursively insert ``term`` and everything it depends on into the graph.

    ``parents`` holds the terms that are currently being expanded on the
    path leading down to ``term``. It exists purely so that dependency
    cycles can be detected.
    """
    if self._frozen:
        owner_name = type(self).__name__
        raise ValueError("Can't mutate %s after construction." % owner_name)

    # Meeting ``term`` again while it is still being expanded means its
    # dependencies loop back on themselves. That should only be possible
    # if a term's inputs were mutated after construction.
    if term in parents:
        raise CyclicDependency(term)

    parents.add(term)
    self.add_node(term)

    for upstream in term.dependencies:
        # Add the dependency's own subgraph first, then the edge pointing
        # from the dependency to ``term``.
        self._add_to_graph(upstream, parents)
        self.add_edge(upstream, term)

    # ``term`` is fully expanded; it is no longer an ancestor of anything.
    parents.remove(term)
@property
def outputs(self):
"""
Dict mapping names to designated output terms.
"""
return self._outputs
def ordered(self):
"""
Return a topologically-sorted iterator over the terms in `self`.
"""
return iter(self._ordered)
@lazyval
def loadable_terms(self):
return tuple(term for term in self if isinstance(term, LoadableTerm))
@lazyval
def jpeg(self):
return display_graph(self, 'jpeg')
@lazyval
def png(self):
return display_graph(self, 'png')
@lazyval
def svg(self):
return display_graph(self, 'svg')
def _repr_png_(self):
return self.png.data
class ExecutionPlan(TermGraph):
"""
Graph represention of Pipeline Term dependencies that includes metadata
about extra rows required to perform computations.
Each node in the graph has an `extra_rows` attribute, indicating how many,
if any, extra rows we should compute for the node. Extra rows are most
@@ -30,6 +135,13 @@ class TermGraph(DiGraph):
----------
terms : dict
A dict mapping names to final output terms.
all_dates : pd.DatetimeIndex
An index of all known trading days for which ``terms`` will be
computed.
start_date : pd.Timestamp
The first date for which output is requested for ``terms``.
end_date : pd.Timestamp
The last date for which output is requested for ``terms``.
Attributes
----------
@@ -42,21 +154,58 @@ class TermGraph(DiGraph):
ordered()
Return a topologically-sorted iterator over the terms in self.
"""
def __init__(self, terms):
super(TermGraph, self).__init__(self)
def __init__(self,
terms,
all_dates,
start_date,
end_date,
min_extra_rows=0):
super(ExecutionPlan, self).__init__(terms)
self._frozen = False
parents = set()
for term in itervalues(terms):
self._add_to_graph(term, parents, extra_rows=0)
# No parents should be left between top-level terms.
assert not parents
for term in terms.values():
self.set_extra_rows(
term,
all_dates,
start_date,
end_date,
min_extra_rows=min_extra_rows,
)
self._outputs = terms
self._ordered = topological_sort(self)
def set_extra_rows(self,
term,
all_dates,
start_date,
end_date,
min_extra_rows):
"""
Compute ``extra_rows`` for transitive dependencies of ``root_terms``
"""
# A term can require that additional extra rows beyond the minimum be
# computed. This is most often used with downsampled terms, which need
# to ensure that the first date is a computation date.
extra_rows_for_term = term.compute_extra_rows(
all_dates,
start_date,
end_date,
min_extra_rows,
)
if extra_rows_for_term < min_extra_rows:
raise ValueError(
"term %s requested fewer rows than the minimum of %d" % (
term, min_extra_rows,
)
)
# Mark that no more terms should be added to the graph.
self._frozen = True
self._ensure_extra_rows(term, extra_rows_for_term)
for dependency, additional_extra_rows in term.dependencies.items():
self.set_extra_rows(
dependency,
all_dates,
start_date,
end_date,
min_extra_rows=extra_rows_for_term + additional_extra_rows,
)
@lazyval
def offset(self):
@@ -138,9 +287,6 @@ class TermGraph(DiGraph):
"""
A dict mapping `term` -> `# of extra rows to load/compute of `term`.
This is always the maximum number of extra **input** rows required by
any Filter/Factor for which `term` is an input.
Notes
----
This value depends on the other terms in the graph that require `term`
@@ -175,71 +321,9 @@ class TermGraph(DiGraph):
for term, attrs in iteritems(self.node)
}
@property
def outputs(self):
"""
Dict mapping names to designated output terms.
"""
return self._outputs
def ordered(self):
"""
Return a topologically-sorted iterator over the terms in `self`.
"""
return iter(self._ordered)
@lazyval
def loadable_terms(self):
return tuple(term for term in self if isinstance(term, LoadableTerm))
def _add_to_graph(self, term, parents, extra_rows):
"""
Add `term` and all its inputs to the graph.
"""
if self._frozen:
raise ValueError("Can't mutate `TermGraph` after construction.")
# If we've seen this node already as a parent of the current traversal,
# it means we have an unsatisifiable dependency. This should only be
# possible if the term's inputs are mutated after construction.
if term in parents:
raise CyclicDependency(term)
parents.add(term)
# Idempotent if term is already in the graph.
self.add_node(term)
# Make sure we're going to compute at least `extra_rows` of `term`.
self._ensure_extra_rows(term, extra_rows)
# Recursively add dependencies.
for dependency, additional_extra_rows in term.dependencies.items():
self._add_to_graph(
dependency,
parents,
extra_rows=extra_rows + additional_extra_rows,
)
self.add_edge(dependency, term)
parents.remove(term)
def _ensure_extra_rows(self, term, N):
"""
Ensure that we're going to compute at least N extra rows of `term`.
"""
attrs = self.node[term]
attrs['extra_rows'] = max(N, attrs.get('extra_rows', 0))
@lazyval
def jpeg(self):
return display_graph(self, 'jpeg')
@lazyval
def png(self):
return display_graph(self, 'png')
@lazyval
def svg(self):
return display_graph(self, 'svg')
def _repr_png_(self):
return self.png.data
+2 -1
View File
@@ -15,6 +15,7 @@ from pandas import (
)
from zipline.lib.adjusted_array import AdjustedArray
from zipline.lib.adjustment import make_adjustment_from_labels
from zipline.utils.numpy_utils import as_column
from zipline.utils.pandas_utils import sort_values
from .base import PipelineLoader
@@ -169,7 +170,7 @@ class DataFrameLoader(PipelineLoader):
# Pull out requested columns/rows from our baseline data.
data=self.baseline[ix_(date_indexer, assets_indexer)],
# Mask out requested columns/rows that didnt match.
mask=(good_assets & good_dates[:, None]) & mask,
mask=(good_assets & as_column(good_dates)) & mask,
adjustments=self.format_adjustments(dates, assets),
missing_value=column.missing_value,
),
+243 -1
View File
@@ -1,16 +1,36 @@
"""
Mixins classes for use with Filters and Factors.
"""
from textwrap import dedent
from numpy import (
array,
full,
recarray,
vstack,
)
from pandas import NaT as pd_NaT
from zipline.errors import (
WindowLengthNotPositive,
UnsupportedDataType,
NoFurtherDataError,
)
from zipline.utils.control_flow import nullctx
from zipline.errors import WindowLengthNotPositive, UnsupportedDataType
from zipline.utils.input_validation import expect_types
from zipline.utils.sharedoc import (
format_docstring,
PIPELINE_DOWNSAMPLING_FREQUENCY_DOC,
)
from zipline.utils.pandas_utils import nearest_unequal_elements
from .downsample_helpers import (
select_sampling_indices,
expect_downsample_frequency,
)
from .sentinels import NotSpecified
from .term import Term
class PositiveWindowLengthMixin(object):
@@ -218,3 +238,225 @@ class LatestMixin(SingleInputMixin):
actual=self.inputs[0].dtype,
)
)
class DownsampledMixin(StandardOutputs):
"""
Mixin for behavior shared by Downsampled{Factor,Filter,Classifier}
A downsampled term is a wrapper around the "real" term that performs actual
computation. The downsampler is responsible for calling the real term's
`compute` method at selected intervals and forward-filling the computed
values.
Downsampling is not currently supported for terms with multiple outputs.
"""
# There's no reason to take a window of a downsampled term. The whole
# point is that you're re-using the same result multiple times.
window_safe = False
@expect_types(term=Term)
@expect_downsample_frequency
def __new__(cls, term, frequency):
return super(DownsampledMixin, cls).__new__(
cls,
inputs=term.inputs,
outputs=term.outputs,
window_length=term.window_length,
mask=term.mask,
frequency=frequency,
wrapped_term=term,
dtype=term.dtype,
missing_value=term.missing_value,
ndim=term.ndim,
)
def _init(self, frequency, wrapped_term, *args, **kwargs):
self._frequency = frequency
self._wrapped_term = wrapped_term
return super(DownsampledMixin, self)._init(*args, **kwargs)
@classmethod
def _static_identity(cls, frequency, wrapped_term, *args, **kwargs):
return (
super(DownsampledMixin, cls)._static_identity(*args, **kwargs),
frequency,
wrapped_term,
)
def compute_extra_rows(self,
                       all_dates,
                       start_date,
                       end_date,
                       min_extra_rows):
    """
    Ensure that min_extra_rows pushes us back to a computation date.

    Parameters
    ----------
    all_dates : pd.DatetimeIndex
        The trading sessions against which ``self`` will be computed.
    start_date : pd.Timestamp
        The first date for which final output is requested.
    end_date : pd.Timestamp
        The last date for which final output is requested.
    min_extra_rows : int
        The minimum number of extra rows required of ``self``, as
        determined by other terms that depend on ``self``.

    Returns
    -------
    extra_rows : int
        The number of extra rows to compute.  This will be the minimum
        number of rows required to make our computed start_date fall on a
        recomputation date.

    Raises
    ------
    ValueError
        If ``start_date`` is not an element of ``all_dates``.
    NoFurtherDataError
        If stepping back ``min_extra_rows`` rows from ``start_date`` would
        land before the first entry of ``all_dates``.
    """
    # Only the calendar lookup can raise KeyError, so keep the ``try`` body
    # down to that single call.
    try:
        start_loc = all_dates.get_loc(start_date)
    except KeyError:
        before, after = nearest_unequal_elements(all_dates, start_date)
        raise ValueError(
            "Pipeline start_date {start_date} is not in calendar.\n"
            "Latest date before start_date is {before}.\n"
            "Earliest date after start_date is {after}.".format(
                start_date=start_date,
                before=before,
                after=after,
            )
        )

    current_start_pos = start_loc - min_extra_rows
    if current_start_pos < 0:
        # These keywords describe a lookback window, so the error has to be
        # built with ``from_lookback_window`` rather than by calling the
        # exception class directly.
        raise NoFurtherDataError.from_lookback_window(
            initial_message="Insufficient data to compute Pipeline:",
            first_date=all_dates[0],
            lookback_start=start_date,
            lookback_length=min_extra_rows,
        )

    # Our possible target dates are all the dates on or before the current
    # starting position.
    # TODO: Consider bounding this below by self.window_length
    candidates = all_dates[:current_start_pos + 1]

    # Choose the latest date in the candidates that is the start of a new
    # period at our frequency.
    choices = select_sampling_indices(candidates, self._frequency)

    # The last choice is the first date of the period containing the
    # current start date.  Choose it.
    new_start_date = candidates[choices[-1]]

    # Add the difference between the new and old start dates to get the
    # number of rows for the new start_date.
    new_start_pos = all_dates.get_loc(new_start_date)
    assert new_start_pos <= current_start_pos, \
        "Computed negative extra rows!"

    return min_extra_rows + (current_start_pos - new_start_pos)
def _compute(self, inputs, dates, assets, mask):
    """
    Compute by delegating to self._wrapped_term._compute on sample dates.

    On non-sample dates, forward-fill from previously-computed samples.

    Parameters
    ----------
    inputs : list
        Inputs for the wrapped term.  When ``self.windowed`` is true these
        are advanced with ``next``; otherwise each one is indexed a single
        row at a time.
    dates
        Labels for the rows being computed.  The first entry must be a
        sampling date for ``self._frequency``.
    assets
        Forwarded unchanged to the wrapped term.
    mask
        Forwarded to the wrapped term one row at a time (``mask[i:i + 1]``).

    Returns
    -------
    The per-date results stacked with ``vstack``, one entry per element of
    ``dates``.
    """
    to_sample = dates[select_sampling_indices(dates, self._frequency)]
    assert to_sample[0] == dates[0], \
        "Misaligned sampling dates in %s." % type(self).__name__

    real_compute = self._wrapped_term._compute

    # Inputs will contain different kinds of values depending on whether or
    # not we're a windowed computation.

    # If we're windowed, then `inputs` is a list of iterators of ndarrays.
    # If we're not windowed, then `inputs` is just a list of ndarrays.
    # There are two things we care about doing with the input:
    #   1. Preparing an input to be passed to our wrapped term.
    #   2. Skipping an input if we're going to use an already-computed row.
    # We perform these actions differently based on the expected kind of
    # input, and we encapsulate these actions with closures so that we
    # don't clutter the code below with lots of branching.
    if self.windowed:
        # If we're windowed, inputs are stateful AdjustedArrays.  We don't
        # need to do any preparation before forwarding to real_compute, but
        # we need to call `next` on them if we want to skip an iteration.
        def prepare_inputs():
            return inputs

        def skip_this_input():
            for w in inputs:
                next(w)
    else:
        # If we're not windowed, inputs are just ndarrays.  We need to
        # slice out a single row when forwarding to real_compute, but we
        # don't need to do anything to skip an input.
        def prepare_inputs():
            # i is the loop iteration variable below.
            return [a[[i]] for a in inputs]

        def skip_this_input():
            pass

    results = []
    samples = iter(to_sample)
    next_sample = next(samples)
    for i, compute_date in enumerate(dates):
        if next_sample == compute_date:
            # Sampling date: run the wrapped term on just this row.
            results.append(
                real_compute(
                    prepare_inputs(),
                    dates[i:i + 1],
                    assets,
                    mask[i:i + 1],
                )
            )
            try:
                next_sample = next(samples)
            except StopIteration:
                # No more samples to take. Set next_sample to Nat, which
                # compares False with any other datetime.
                next_sample = pd_NaT
        else:
            skip_this_input()
            # Copy results from previous sample period.
            results.append(results[-1])

    # We should have exhausted our sample dates.
    try:
        next_sample = next(samples)
    except StopIteration:
        pass
    else:
        raise AssertionError("Unconsumed sample date: %s" % next_sample)

    # Concatenate stored results.
    return vstack(results)
@classmethod
def make_downsampled_type(cls, other_base):
"""
Factory for making Downsampled{Filter,Factor,Classifier}.
"""
docstring = dedent(
"""
A {t} that defers to another {t} at lower-than-daily frequency.
Parameters
----------
term : {t}
{{frequency}}
"""
).format(t=other_base.__name__)
doc = format_docstring(
owner_name=other_base.__name__,
docstring=docstring,
formatters={'frequency': PIPELINE_DOWNSAMPLING_FREQUENCY_DOC},
)
return type(
'Downsampled' + other_base.__name__,
(cls, other_base,),
{'__doc__': doc,
'__module__': other_base.__module__},
)
+51 -9
View File
@@ -1,10 +1,14 @@
from zipline.errors import UnsupportedPipelineOutput
from zipline.utils.input_validation import expect_types, optional
from zipline.utils.input_validation import (
expect_element,
expect_types,
optional,
)
from .term import AssetExists, ComputableTerm, Term
from .graph import ExecutionPlan, TermGraph
from .filters import Filter
from .graph import TermGraph
from .term import AssetExists, ComputableTerm, Term
class Pipeline(object):
@@ -148,9 +152,39 @@ class Pipeline(object):
)
self._screen = screen
def to_graph(self, screen_name, default_screen):
def to_execution_plan(self,
screen_name,
default_screen,
all_dates,
start_date,
end_date):
"""
Compile into a TermGraph.
Compile into an ExecutionPlan.
Parameters
----------
screen_name : str
Name to supply for self.screen.
default_screen : zipline.pipeline.term.Term
Term to use as a screen if self.screen is None.
all_dates : pd.DatetimeIndex
A calendar of dates to use to calculate starts and ends for each
term.
start_date : pd.Timestamp
The first date of requested output.
end_date : pd.Timestamp
The last date of requested output.
"""
return ExecutionPlan(
self._prepare_graph_terms(screen_name, default_screen),
all_dates,
start_date,
end_date,
)
def to_simple_graph(self, screen_name, default_screen):
"""
Compile into a simple TermGraph with no extra row metadata.
Parameters
----------
@@ -159,14 +193,20 @@ class Pipeline(object):
default_screen : zipline.pipeline.term.Term
Term to use as a screen if self.screen is None.
"""
return TermGraph(
self._prepare_graph_terms(screen_name, default_screen)
)
def _prepare_graph_terms(self, screen_name, default_screen):
    """
    Build the name -> term mapping shared by to_simple_graph and
    to_execution_plan.

    The result is a copy of ``self.columns`` with one extra entry,
    ``screen_name``, bound to ``self.screen`` or to ``default_screen`` when
    no screen has been set.
    """
    terms = self.columns.copy()
    chosen_screen = self.screen
    terms[screen_name] = (
        default_screen if chosen_screen is None else chosen_screen
    )
    return terms
return TermGraph(columns)
@expect_element(format=('svg', 'png', 'jpeg'))
def show_graph(self, format='svg'):
"""
Render this Pipeline as a DAG.
@@ -176,7 +216,7 @@ class Pipeline(object):
format : {'svg', 'png', 'jpeg'}
Image format to render with. Default is 'svg'.
"""
g = self.to_graph('', AssetExists())
g = self.to_simple_graph('', AssetExists())
if format == 'svg':
return g.svg
elif format == 'png':
@@ -184,7 +224,9 @@ class Pipeline(object):
elif format == 'jpeg':
return g.jpeg
else:
raise ValueError("Unknown graph format %r." % format)
# We should never get here because of the expect_element decorator
# above.
raise AssertionError("Unknown graph format %r." % format)
@staticmethod
def validate_column(column_name, term):
+104 -3
View File
@@ -34,10 +34,15 @@ from zipline.utils.memoize import lazyval
from zipline.utils.numpy_utils import (
bool_dtype,
categorical_dtype,
datetime64ns_dtype,
default_missing_value_for_dtype,
)
from zipline.utils.sharedoc import (
templated_docstring,
PIPELINE_DOWNSAMPLING_FREQUENCY_DOC,
)
from .mixins import SingleInputMixin
from .downsample_helpers import expect_downsample_frequency
from .sentinels import NotSpecified
@@ -279,6 +284,39 @@ class Term(with_metaclass(ABCMeta, object)):
# call super().
self._subclass_called_super_validate = True
def compute_extra_rows(self,
all_dates,
start_date,
end_date,
min_extra_rows):
"""
Calculate the number of extra rows needed to compute ``self``.
Must return at least ``min_extra_rows``, and the default implementation
is to just return ``min_extra_rows``. This is overridden by
downsampled terms to ensure that the first date computed is a
recomputation date.
Parameters
----------
all_dates : pd.DatetimeIndex
The trading sessions against which ``self`` will be computed.
start_date : pd.Timestamp
The first date for which final output is requested.
end_date : pd.Timestamp
The last date for which final output is requested.
min_extra_rows : int
The minimum number of extra rows required of ``self``, as
determined by other terms that depend on ``self``.
Returns
-------
extra_rows : int
The number of extra rows to compute. Must be at least
``min_extra_rows``.
"""
return min_extra_rows
@abstractproperty
def inputs(self):
"""
@@ -320,6 +358,9 @@ class AssetExists(Term):
every asset on every date. We don't subclass Filter, however, because
`AssetExists` is computed directly by the PipelineEngine.
This term is guaranteed to be available as an input for any term computed
by SimplePipelineEngine.run_pipeline().
See Also
--------
zipline.assets.AssetFinder.lifetimes
@@ -334,6 +375,38 @@ class AssetExists(Term):
def __repr__(self):
return "AssetExists()"
def _compute(self, today, assets, out):
raise NotImplementedError(
"AssetExists cannot be computed directly."
" Check your PipelineEngine configuration."
)
class InputDates(Term):
"""
1-Dimensional term providing date labels for other term inputs.
This term is guaranteed to be available as an input for any term computed
by SimplePipelineEngine.run_pipeline().
"""
ndim = 1
dataset = None
dtype = datetime64ns_dtype
inputs = ()
dependencies = {}
mask = None
windowed = False
window_safe = True
def __repr__(self):
return "InputDates()"
def _compute(self, today, assets, out):
raise NotImplementedError(
"InputDates cannot be computed directly."
" Check your PipelineEngine configuration."
)
class LoadableTerm(Term):
"""
@@ -342,6 +415,7 @@ class LoadableTerm(Term):
This is the base class for :class:`zipline.pipeline.data.BoundColumn`.
"""
windowed = False
inputs = ()
@lazyval
def dependencies(self):
@@ -353,7 +427,7 @@ class ComputableTerm(Term):
A Term that should be computed from a tuple of inputs.
This is the base class for :class:`zipline.pipeline.Factor`,
:class:`zipline.pipeline.Filter`, and :class:`zipline.pipeline.Factor`.
:class:`zipline.pipeline.Filter`, and :class:`zipline.pipeline.Classifier`.
"""
inputs = NotSpecified
outputs = NotSpecified
@@ -516,6 +590,27 @@ class ComputableTerm(Term):
"""
return data
@property
def _downsampled_type(self):
    """
    The expression type to return from self.downsample().

    ``downsample()`` reads this as an attribute and then calls the result
    with ``term`` and ``frequency`` keywords.  It is therefore a property:
    merely looking it up on a term that does not support downsampling
    raises the error below, rather than failing later with a ``TypeError``
    about unexpected keyword arguments.

    Raises
    ------
    NotImplementedError
        Always, for terms that do not override this attribute.
    """
    raise NotImplementedError(
        "downsampling is not yet implemented "
        "for instances of %s." % type(self).__name__
    )
@expect_downsample_frequency
@templated_docstring(frequency=PIPELINE_DOWNSAMPLING_FREQUENCY_DOC)
def downsample(self, frequency):
"""
Make a term that computes from ``self`` at lower-than-daily frequency.
Parameters
----------
{frequency}
"""
return self._downsampled_type(term=self, frequency=frequency)
def __repr__(self):
return (
"{type}({inputs}, window_length={window_length})"
@@ -526,7 +621,7 @@ class ComputableTerm(Term):
)
class Slice(ComputableTerm, SingleInputMixin):
class Slice(ComputableTerm):
"""
Term for extracting a single column of a another term's output.
@@ -582,6 +677,12 @@ class Slice(ComputableTerm, SingleInputMixin):
# column.
return windows[0][:, [asset_column]]
@property
def _downsampled_type(self):
raise NotImplementedError(
'downsampling of slices is not yet supported'
)
def validate_dtype(termname, dtype, missing_value):
"""
-1
View File
@@ -30,7 +30,6 @@ from .core import ( # noqa
make_test_handler,
make_trade_data_for_asset_info,
parameter_space,
parameter_space,
patch_os_environment,
patch_read_csv,
permute_rows,
+6 -4
View File
@@ -47,9 +47,11 @@ from zipline.pipeline.engine import SimplePipelineEngine
from zipline.pipeline.factors import CustomFactor
from zipline.pipeline.loaders.testing import make_seeded_random_loader
from zipline.utils import security_list
from zipline.utils.input_validation import expect_dimensions
from zipline.utils.sentinel import sentinel
from zipline.utils.calendars import get_calendar
from zipline.utils.input_validation import expect_dimensions
from zipline.utils.numpy_utils import as_column
from zipline.utils.sentinel import sentinel
import numpy as np
from numpy import float64
@@ -308,11 +310,11 @@ def make_trade_data_for_asset_info(dates,
price_sid_deltas = np.arange(len(sids), dtype=float64) * price_step_by_sid
price_date_deltas = (np.arange(len(dates), dtype=float64) *
price_step_by_date)
prices = (price_sid_deltas + price_date_deltas[:, None]) + price_start
prices = (price_sid_deltas + as_column(price_date_deltas)) + price_start
volume_sid_deltas = np.arange(len(sids)) * volume_step_by_sid
volume_date_deltas = np.arange(len(dates)) * volume_step_by_date
volumes = (volume_sid_deltas + volume_date_deltas[:, None]) + volume_start
volumes = volume_sid_deltas + as_column(volume_date_deltas) + volume_start
for j, sid in enumerate(sids):
start_date, end_date = asset_info.loc[sid, ['start_date', 'end_date']]
+8 -3
View File
@@ -235,6 +235,9 @@ class WithDefaultDateBounds(object):
ZiplineTestCase mixin which makes it possible to synchronize date bounds
across fixtures.
This fixture should always be the last fixture in bases of any fixture or
test case that uses it.
Attributes
----------
START_DATE : datetime
@@ -420,7 +423,9 @@ class WithTradingCalendars(object):
cls.trading_calendars[exchange] = get_calendar(cal_str)
class WithTradingEnvironment(WithAssetFinder, WithTradingCalendars):
class WithTradingEnvironment(WithAssetFinder,
WithTradingCalendars,
WithDefaultDateBounds):
"""
ZiplineTestCase mixin providing cls.env as a class-level fixture.
@@ -527,7 +532,7 @@ class WithSimParams(WithTradingEnvironment):
cls.sim_params = cls.make_simparams()
class WithTradingSessions(WithTradingCalendars):
class WithTradingSessions(WithTradingCalendars, WithDefaultDateBounds):
"""
ZiplineTestCase mixin providing cls.trading_days, cls.all_trading_sessions
as a class-level fixture.
@@ -1212,7 +1217,7 @@ class WithSeededRandomPipelineEngine(WithTradingSessions, WithAssetFinder):
if start_date not in self.trading_days:
raise AssertionError("Start date not in calendar: %s" % start_date)
if end_date not in self.trading_days:
raise AssertionError("Start date not in calendar: %s" % start_date)
raise AssertionError("End date not in calendar: %s" % end_date)
return self.seeded_random_engine.run_pipeline(
pipeline,
start_date,
+8 -1
View File
@@ -483,10 +483,17 @@ def expect_element(*_pos, **named):
raise TypeError("expect_element() only takes keyword arguments.")
def _expect_element(collection):
if isinstance(collection, (set, frozenset)):
# Special case the error message for set and frozen set to make it
# less verbose.
collection_for_error_message = tuple(sorted(collection))
else:
collection_for_error_message = collection
template = (
"%(funcname)s() expected a value in {collection} "
"for argument '%(argname)s', but got %(actual)s instead."
).format(collection=collection)
).format(collection=collection_for_error_message)
return make_check(
ValueError,
template,
+66
View File
@@ -11,8 +11,11 @@ from numpy import (
broadcast,
busday_count,
datetime64,
diff,
dtype,
empty,
flatnonzero,
hstack,
nan,
vectorize,
where
@@ -364,3 +367,66 @@ def vectorized_is_element(array, choices):
Array indicating whether each element of ``array`` was in ``choices``.
"""
return vectorize(choices.__contains__, otypes=[bool])(array)
def as_column(a):
    """
    Convert an array of shape (N,) into an array of shape (N, 1).

    This is equivalent to `a[:, np.newaxis]`.

    Parameters
    ----------
    a : np.ndarray

    Returns
    -------
    column : np.ndarray
        ``a`` indexed as ``a[:, None]``, with shape (N, 1).

    Raises
    ------
    ValueError
        If ``a`` is not 1-dimensional.

    Example
    -------
    >>> import numpy as np
    >>> a = np.arange(5)
    >>> a
    array([0, 1, 2, 3, 4])
    >>> as_column(a)
    array([[0],
           [1],
           [2],
           [3],
           [4]])
    >>> as_column(a).shape
    (5, 1)
    """
    if a.ndim != 1:
        # ``a.shape`` is a tuple, so it must be wrapped in a 1-tuple.
        # Passing it bare makes ``%`` treat each dimension as a separate
        # format argument and raise TypeError instead of this ValueError.
        raise ValueError(
            "as_column expected an 1-dimensional array, "
            "but got an array of shape %s" % (a.shape,)
        )
    return a[:, None]
def changed_locations(a, include_first):
    """
    Compute indices of values in ``a`` that differ from the previous value.

    Parameters
    ----------
    a : np.ndarray
        The array on which to compute indices of change.
    include_first : bool
        Whether or not to consider the first index of the array as "changed".

    Returns
    -------
    indices : np.ndarray
        Positions ``i`` for which ``a[i] != a[i - 1]``, preceded by ``0``
        when ``include_first`` is true.

    Raises
    ------
    ValueError
        If ``a`` has more than one dimension.

    Example
    -------
    >>> import numpy as np
    >>> changed_locations(np.array([0, 0, 5, 5, 1, 1]), include_first=False)
    array([2, 4])

    >>> changed_locations(np.array([0, 0, 5, 5, 1, 1]), include_first=True)
    array([0, 2, 4])
    """
    if a.ndim > 1:
        raise ValueError("changed_locations only supports 1D arrays.")

    # diff() is nonzero wherever consecutive values differ; shifting by one
    # turns "position of the difference" into "position of the new value".
    indices = flatnonzero(diff(a)) + 1

    if not include_first:
        return indices
    return hstack([[0], indices])
+50
View File
@@ -96,3 +96,53 @@ def mask_between_time(dts, start, end, include_start=True, include_end=True):
left_op(start_micros, time_micros),
right_op(time_micros, end_micros),
)
def nearest_unequal_elements(dts, dt):
    """
    Locate the neighbours of ``dt`` within ``dts``, excluding ``dt`` itself.

    The result is a pair ``(last_before, first_after)``: the greatest
    element of ``dts`` strictly less than ``dt`` and the smallest element
    strictly greater than ``dt``.  Either side is None when no such element
    exists, so both are None for an empty ``dts``.

    ``dts`` must be unique and sorted in increasing order.

    Parameters
    ----------
    dts : pd.DatetimeIndex
        Dates in which to search.
    dt : pd.Timestamp
        Date for which to find bounds.

    Raises
    ------
    ValueError
        If ``dts`` contains duplicates or is not sorted in increasing order.
    """
    if not dts.is_unique:
        raise ValueError("dts must be unique")
    if not dts.is_monotonic_increasing:
        raise ValueError("dts must be sorted in increasing order")

    count = len(dts)
    if not count:
        return None, None

    insert_at = dts.searchsorted(dt, side='left')
    try:
        found = dts[insert_at]
    except IndexError:
        # The insertion point is past the end: every element is below dt.
        return dts[-1], None

    if dt < found:
        # dt falls in the gap just before ``found``.
        before_ix, after_ix = insert_at - 1, insert_at
    elif dt == found:
        # dt is present; step over it in both directions.
        before_ix, after_ix = insert_at - 1, insert_at + 1
    else:
        before_ix, after_ix = insert_at, insert_at + 1

    last_before = None if before_ix < 0 else dts[before_ix]
    first_after = None if after_ix >= count else dts[after_ix]
    return last_before, first_after
+90
View File
@@ -0,0 +1,90 @@
"""
Shared docstrings for parameters that should be documented identically
across different functions.
"""
import re
from six import iteritems
from textwrap import dedent
PIPELINE_DOWNSAMPLING_FREQUENCY_DOC = dedent(
"""\
frequency : {'year_start', 'quarter_start', 'month_start', 'week_start'}
A string indicating desired sampling dates:
'year_start' -> first trading day of each year
'quarter_start' -> first trading day of January, April, July, October
'month_start' -> first trading day of each month
'week_start' -> first trading_day of each week
"""
)
def pad_lines(prefix, s):
    """Return ``s`` with ``prefix`` prepended to every one of its lines."""
    padded = [prefix + line for line in s.splitlines()]
    return '\n'.join(padded)
def format_docstring(owner_name, docstring, formatters):
    """
    Template ``formatters`` into ``docstring``.

    Parameters
    ----------
    owner_name : str
        The name of the function or class whose docstring is being templated.
        Only used for error messages.
    docstring : str
        The docstring to template.
    formatters : dict[str -> str]
        Parameters for a str.format() call on ``docstring``.

        Multi-line values in ``formatters`` will have leading whitespace padded
        to match the leading whitespace of the substitution string.

    Returns
    -------
    formatted : str
        ``docstring`` with every ``{name}`` placeholder replaced.

    Raises
    ------
    ValueError
        If a placeholder for one of ``formatters`` is missing from
        ``docstring``, or appears more than once.
    """
    # Build a dict of parameters to a vanilla format() call by searching for
    # each entry in **formatters and applying any leading whitespace to each
    # line in the desired substitution.
    format_params = {}
    for target, doc_for_target in iteritems(formatters):
        # Search for '{name}' alone on a line, with optional indentation.
        # Only spaces and tabs count as indentation: ``\s`` would also
        # match newlines and could pull a preceding blank line into the
        # captured prefix.
        regex = re.compile(
            r'^([ \t]*)' + '({' + re.escape(target) + '})$',
            re.MULTILINE,
        )
        matches = regex.findall(docstring)
        if not matches:
            raise ValueError(
                "Couldn't find template for parameter {!r} in docstring "
                "for {}."
                "\nParameter name must be alone on a line surrounded by "
                "braces.".format(target, owner_name),
            )
        elif len(matches) > 1:
            raise ValueError(
                "Found multiple templates for parameter {!r} "
                "in docstring for {}."
                "\nParameter should only appear once.".format(
                    target, owner_name
                )
            )

        (leading_whitespace, _) = matches[0]
        padded = pad_lines(leading_whitespace, doc_for_target)
        # The placeholder's own line already supplies the indentation for
        # the first substituted line, so drop the prefix that pad_lines put
        # there; otherwise the first line would be indented twice.
        format_params[target] = padded[len(leading_whitespace):]

    return docstring.format(**format_params)
def templated_docstring(**docs):
    """
    Build a decorator that fills in a function's templated docstring.

    Each keyword names a ``{placeholder}`` in the decorated function's
    docstring and supplies the text to substitute for it.  The function is
    returned unchanged apart from its ``__doc__``.

    Usage
    -----
    >>> @templated_docstring(foo='bar')
    ... def my_func(self, foo):
    ...     '''{foo}'''
    ...
    >>> my_func.__doc__
    'bar'
    """
    def decorator(func):
        rendered = format_docstring(func.__name__, func.__doc__, docs)
        func.__doc__ = rendered
        return func

    return decorator