mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 18:27:14 +08:00
Merge pull request #1394 from quantopian/downsample
Add Generic Downsampling to Pipeline
This commit is contained in:
+1
-1
@@ -101,7 +101,7 @@ install:
|
||||
- pip freeze | sort
|
||||
|
||||
test_script:
|
||||
- nosetests
|
||||
- nosetests -e zipline.utils.numpy_utils
|
||||
- flake8 zipline tests
|
||||
|
||||
branches:
|
||||
|
||||
+41
-31
@@ -1,24 +1,23 @@
|
||||
"""
|
||||
Base class for Pipeline API unittests.
|
||||
Base class for Pipeline API unit tests.
|
||||
"""
|
||||
from functools import wraps
|
||||
|
||||
import numpy as np
|
||||
from numpy import arange, prod
|
||||
from pandas import date_range, Int64Index, DataFrame
|
||||
from pandas import DataFrame, Timestamp
|
||||
from six import iteritems
|
||||
|
||||
from zipline.assets.synthetic import make_simple_equity_info
|
||||
from zipline.pipeline.engine import SimplePipelineEngine
|
||||
from zipline.pipeline import TermGraph
|
||||
from zipline.pipeline.term import AssetExists
|
||||
from zipline.pipeline import ExecutionPlan
|
||||
from zipline.pipeline.term import AssetExists, InputDates
|
||||
from zipline.testing import (
|
||||
check_arrays,
|
||||
ExplodingObject,
|
||||
tmp_asset_finder,
|
||||
)
|
||||
from zipline.testing.fixtures import (
|
||||
WithTradingCalendars,
|
||||
WithAssetFinder,
|
||||
WithTradingSessions,
|
||||
ZiplineTestCase,
|
||||
)
|
||||
|
||||
@@ -53,32 +52,26 @@ def with_defaults(**default_funcs):
|
||||
with_default_shape = with_defaults(shape=lambda self: self.default_shape)
|
||||
|
||||
|
||||
class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase):
|
||||
class BasePipelineTestCase(WithTradingSessions,
|
||||
WithAssetFinder,
|
||||
ZiplineTestCase):
|
||||
START_DATE = Timestamp('2014', tz='UTC')
|
||||
END_DATE = Timestamp('2014-12-31', tz='UTC')
|
||||
ASSET_FINDER_EQUITY_SIDS = list(range(20))
|
||||
|
||||
@classmethod
|
||||
def init_class_fixtures(cls):
|
||||
super(BasePipelineTestCase, cls).init_class_fixtures()
|
||||
|
||||
cls.__calendar = date_range('2014', '2015',
|
||||
freq=cls.trading_calendar.day)
|
||||
cls.__assets = assets = Int64Index(arange(1, 20))
|
||||
cls.__tmp_finder_ctx = tmp_asset_finder(
|
||||
equities=make_simple_equity_info(
|
||||
assets,
|
||||
cls.__calendar[0],
|
||||
cls.__calendar[-1],
|
||||
)
|
||||
)
|
||||
cls.__finder = cls.__tmp_finder_ctx.__enter__()
|
||||
cls.__mask = cls.__finder.lifetimes(
|
||||
cls.__calendar[-30:],
|
||||
cls.default_asset_exists_mask = cls.asset_finder.lifetimes(
|
||||
cls.nyse_sessions[-30:],
|
||||
include_start_date=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def default_shape(self):
|
||||
"""Default shape for methods that build test data."""
|
||||
return self.__mask.shape
|
||||
return self.default_asset_exists_mask.shape
|
||||
|
||||
def run_graph(self, graph, initial_workspace, mask=None):
|
||||
"""
|
||||
@@ -103,14 +96,17 @@ class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase):
|
||||
"""
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column: ExplodingObject(),
|
||||
self.__calendar,
|
||||
self.__finder,
|
||||
self.nyse_sessions,
|
||||
self.asset_finder,
|
||||
)
|
||||
if mask is None:
|
||||
mask = self.__mask
|
||||
mask = self.default_asset_exists_mask
|
||||
|
||||
dates, assets, mask_values = explode(mask)
|
||||
|
||||
initial_workspace.setdefault(AssetExists(), mask_values)
|
||||
initial_workspace.setdefault(InputDates(), dates)
|
||||
|
||||
return engine.compute_chunk(
|
||||
graph,
|
||||
dates,
|
||||
@@ -118,15 +114,29 @@ class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase):
|
||||
initial_workspace,
|
||||
)
|
||||
|
||||
def check_terms(self, terms, expected, initial_workspace, mask):
|
||||
def check_terms(self,
|
||||
terms,
|
||||
expected,
|
||||
initial_workspace,
|
||||
mask,
|
||||
check=check_arrays):
|
||||
"""
|
||||
Compile the given terms into a TermGraph, compute it with
|
||||
initial_workspace, and compare the results with ``expected``.
|
||||
"""
|
||||
graph = TermGraph(terms)
|
||||
start_date, end_date = mask.index[[0, -1]]
|
||||
graph = ExecutionPlan(
|
||||
terms,
|
||||
all_dates=self.nyse_sessions,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
results = self.run_graph(graph, initial_workspace, mask)
|
||||
for key, (res, exp) in dzip_exact(results, expected).items():
|
||||
check_arrays(res, exp)
|
||||
check(res, exp)
|
||||
|
||||
return results
|
||||
|
||||
def build_mask(self, array):
|
||||
"""
|
||||
@@ -138,13 +148,13 @@ class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase):
|
||||
array,
|
||||
# Use the **last** N dates rather than the first N so that we have
|
||||
# space for lookbacks.
|
||||
index=self.__calendar[-ndates:],
|
||||
columns=self.__assets[:nassets],
|
||||
index=self.nyse_sessions[-ndates:],
|
||||
columns=self.ASSET_FINDER_EQUITY_SIDS[:nassets],
|
||||
dtype=bool,
|
||||
)
|
||||
|
||||
@with_default_shape
|
||||
def arange_data(self, shape, dtype=float):
|
||||
def arange_data(self, shape, dtype=np.float64):
|
||||
"""
|
||||
Build a block of testing data from numpy.arange.
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,686 @@
|
||||
"""
|
||||
Tests for Downsampled Filters/Factors/Classifiers
|
||||
"""
|
||||
import pandas as pd
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
|
||||
from zipline.pipeline import (
|
||||
Pipeline,
|
||||
CustomFactor,
|
||||
CustomFilter,
|
||||
CustomClassifier,
|
||||
)
|
||||
from zipline.pipeline.data.testing import TestingDataSet
|
||||
from zipline.pipeline.factors import SimpleMovingAverage
|
||||
from zipline.pipeline.filters.smoothing import All
|
||||
from zipline.testing import ZiplineTestCase, parameter_space
|
||||
from zipline.testing.fixtures import (
|
||||
WithTradingSessions,
|
||||
WithSeededRandomPipelineEngine,
|
||||
)
|
||||
from zipline.utils.input_validation import _qualified_name
|
||||
from zipline.utils.numpy_utils import int64_dtype
|
||||
|
||||
|
||||
class NDaysAgoFactor(CustomFactor):
|
||||
inputs = [TestingDataSet.float_col]
|
||||
|
||||
def compute(self, today, assets, out, floats):
|
||||
out[:] = floats[0]
|
||||
|
||||
|
||||
class NDaysAgoFilter(CustomFilter):
|
||||
inputs = [TestingDataSet.bool_col]
|
||||
|
||||
def compute(self, today, assets, out, bools):
|
||||
out[:] = bools[0]
|
||||
|
||||
|
||||
class NDaysAgoClassifier(CustomClassifier):
|
||||
inputs = [TestingDataSet.categorical_col]
|
||||
dtype = TestingDataSet.categorical_col.dtype
|
||||
|
||||
def compute(self, today, assets, out, cats):
|
||||
out[:] = cats[0]
|
||||
|
||||
|
||||
class ComputeExtraRowsTestcase(WithTradingSessions, ZiplineTestCase):
|
||||
|
||||
DATA_MIN_DAY = pd.Timestamp('2012-06', tz='UTC')
|
||||
DATA_MAX_DAY = pd.Timestamp('2015', tz='UTC')
|
||||
TRADING_CALENDAR_STRS = ('NYSE',)
|
||||
|
||||
# Test with different window_lengths to ensure that window length is not
|
||||
# used when calculating extra rows for the top-level term.
|
||||
factor1 = TestingDataSet.float_col.latest
|
||||
factor11 = NDaysAgoFactor(window_length=11)
|
||||
factor91 = NDaysAgoFactor(window_length=91)
|
||||
|
||||
filter1 = TestingDataSet.bool_col.latest
|
||||
filter11 = NDaysAgoFilter(window_length=11)
|
||||
filter91 = NDaysAgoFilter(window_length=91)
|
||||
|
||||
classifier1 = TestingDataSet.categorical_col.latest
|
||||
classifier11 = NDaysAgoClassifier(window_length=11)
|
||||
classifier91 = NDaysAgoClassifier(window_length=91)
|
||||
|
||||
all_terms = [
|
||||
factor1,
|
||||
factor11,
|
||||
factor91,
|
||||
filter1,
|
||||
filter11,
|
||||
filter91,
|
||||
classifier1,
|
||||
classifier11,
|
||||
classifier91,
|
||||
]
|
||||
|
||||
@parameter_space(
|
||||
calendar_name=TRADING_CALENDAR_STRS,
|
||||
base_terms=[
|
||||
(factor1, factor11, factor91),
|
||||
(filter1, filter11, filter91),
|
||||
(classifier1, classifier11, classifier91),
|
||||
],
|
||||
__fail_fast=True
|
||||
)
|
||||
def test_yearly(self, base_terms, calendar_name):
|
||||
downsampled_terms = tuple(
|
||||
t.downsample('year_start') for t in base_terms
|
||||
)
|
||||
all_terms = base_terms + downsampled_terms
|
||||
|
||||
all_sessions = self.trading_sessions[calendar_name]
|
||||
end_session = all_sessions[-1]
|
||||
|
||||
years = all_sessions.year
|
||||
sessions_in_2012 = all_sessions[years == 2012]
|
||||
sessions_in_2013 = all_sessions[years == 2013]
|
||||
sessions_in_2014 = all_sessions[years == 2014]
|
||||
|
||||
# Simulate requesting computation where the unaltered lookback would
|
||||
# land exactly on the first date in 2014. We shouldn't request any
|
||||
# additional rows for the regular terms or the downsampled terms.
|
||||
for i in range(0, 30, 5):
|
||||
start_session = sessions_in_2014[i]
|
||||
self.check_extra_row_calculations(
|
||||
all_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i,
|
||||
expected_extra_rows=i,
|
||||
)
|
||||
|
||||
# Simulate requesting computation where the unaltered lookback would
|
||||
# land on the second date in 2014. We should request one more extra
|
||||
# row in the downsampled terms to push us back to the first date in
|
||||
# 2014.
|
||||
for i in range(0, 30, 5):
|
||||
start_session = sessions_in_2014[i + 1]
|
||||
self.check_extra_row_calculations(
|
||||
downsampled_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i,
|
||||
expected_extra_rows=i + 1,
|
||||
)
|
||||
self.check_extra_row_calculations(
|
||||
base_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i,
|
||||
expected_extra_rows=i,
|
||||
)
|
||||
|
||||
# Simulate requesting computation where the unaltered lookback would
|
||||
# land on the last date of 2013. The downsampled terms should request
|
||||
# enough extra rows to push us back to the start of 2013.
|
||||
for i in range(0, 30, 5):
|
||||
start_session = sessions_in_2014[i]
|
||||
self.check_extra_row_calculations(
|
||||
downsampled_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i + 1,
|
||||
expected_extra_rows=i + len(sessions_in_2013),
|
||||
)
|
||||
self.check_extra_row_calculations(
|
||||
base_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i + 1,
|
||||
expected_extra_rows=i + 1,
|
||||
)
|
||||
|
||||
# Simulate requesting computation where the unaltered lookback would
|
||||
# land on the last date of 2012. The downsampled terms should request
|
||||
# enough extra rows to push us back to the first known date, which is
|
||||
# in the middle of 2012
|
||||
for i in range(0, 30, 5):
|
||||
start_session = sessions_in_2013[i]
|
||||
self.check_extra_row_calculations(
|
||||
downsampled_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i + 1,
|
||||
expected_extra_rows=i + len(sessions_in_2012),
|
||||
)
|
||||
self.check_extra_row_calculations(
|
||||
base_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i + 1,
|
||||
expected_extra_rows=i + 1,
|
||||
)
|
||||
|
||||
@parameter_space(
|
||||
calendar_name=TRADING_CALENDAR_STRS,
|
||||
base_terms=[
|
||||
(factor1, factor11, factor91),
|
||||
(filter1, filter11, filter91),
|
||||
(classifier1, classifier11, classifier91),
|
||||
],
|
||||
__fail_fast=True
|
||||
)
|
||||
def test_quarterly(self, calendar_name, base_terms):
|
||||
downsampled_terms = tuple(
|
||||
t.downsample('quarter_start') for t in base_terms
|
||||
)
|
||||
all_terms = base_terms + downsampled_terms
|
||||
|
||||
# This region intersects with Q4 2013, Q1 2014, and Q2 2014.
|
||||
tmp = self.trading_sessions[calendar_name]
|
||||
all_sessions = tmp[tmp.slice_indexer('2013-12-15', '2014-04-30')]
|
||||
end_session = all_sessions[-1]
|
||||
|
||||
months = all_sessions.month
|
||||
Q4_2013 = all_sessions[months == 12]
|
||||
Q1_2014 = all_sessions[(months == 1) | (months == 2) | (months == 3)]
|
||||
Q2_2014 = all_sessions[months == 4]
|
||||
|
||||
# Simulate requesting computation where the unaltered lookback would
|
||||
# land exactly on the first date in Q2 2014. We shouldn't request any
|
||||
# additional rows for the regular terms or the downsampled terms.
|
||||
for i in range(0, 15, 5):
|
||||
start_session = Q2_2014[i]
|
||||
self.check_extra_row_calculations(
|
||||
all_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i,
|
||||
expected_extra_rows=i,
|
||||
)
|
||||
|
||||
# Simulate requesting computation where the unaltered lookback would
|
||||
# land exactly on the second date in Q2 2014.
|
||||
# The downsampled terms should request one more extra row.
|
||||
for i in range(0, 15, 5):
|
||||
start_session = Q2_2014[i + 1]
|
||||
self.check_extra_row_calculations(
|
||||
downsampled_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i,
|
||||
expected_extra_rows=i + 1,
|
||||
)
|
||||
self.check_extra_row_calculations(
|
||||
base_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i,
|
||||
expected_extra_rows=i,
|
||||
)
|
||||
|
||||
# Simulate requesting computation where the unaltered lookback would
|
||||
# land exactly on the last date in Q1 2014. The downsampled terms
|
||||
# should request enough extra rows to push us back to the first date of
|
||||
# Q1 2014.
|
||||
for i in range(0, 15, 5):
|
||||
start_session = Q2_2014[i]
|
||||
self.check_extra_row_calculations(
|
||||
downsampled_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i + 1,
|
||||
expected_extra_rows=i + len(Q1_2014),
|
||||
)
|
||||
self.check_extra_row_calculations(
|
||||
base_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i + 1,
|
||||
expected_extra_rows=i + 1,
|
||||
)
|
||||
|
||||
# Simulate requesting computation where the unaltered lookback would
|
||||
# land exactly on the last date in Q4 2013. The downsampled terms
|
||||
# should request enough extra rows to push us back to the first known
|
||||
# date, which is in the middle of december 2013.
|
||||
for i in range(0, 15, 5):
|
||||
start_session = Q1_2014[i]
|
||||
self.check_extra_row_calculations(
|
||||
downsampled_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i + 1,
|
||||
expected_extra_rows=i + len(Q4_2013),
|
||||
)
|
||||
self.check_extra_row_calculations(
|
||||
base_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i + 1,
|
||||
expected_extra_rows=i + 1,
|
||||
)
|
||||
|
||||
@parameter_space(
|
||||
calendar_name=TRADING_CALENDAR_STRS,
|
||||
base_terms=[
|
||||
(factor1, factor11, factor91),
|
||||
(filter1, filter11, filter91),
|
||||
(classifier1, classifier11, classifier91),
|
||||
],
|
||||
__fail_fast=True
|
||||
)
|
||||
def test_monthly(self, calendar_name, base_terms):
|
||||
downsampled_terms = tuple(
|
||||
t.downsample('month_start') for t in base_terms
|
||||
)
|
||||
all_terms = base_terms + downsampled_terms
|
||||
|
||||
# This region intersects with Dec 2013, Jan 2014, and Feb 2014.
|
||||
tmp = self.trading_sessions[calendar_name]
|
||||
all_sessions = tmp[tmp.slice_indexer('2013-12-15', '2014-02-28')]
|
||||
end_session = all_sessions[-1]
|
||||
|
||||
months = all_sessions.month
|
||||
dec2013 = all_sessions[months == 12]
|
||||
jan2014 = all_sessions[months == 1]
|
||||
feb2014 = all_sessions[months == 2]
|
||||
|
||||
# Simulate requesting computation where the unaltered lookback would
|
||||
# land exactly on the first date in feb 2014. We shouldn't request any
|
||||
# additional rows for the regular terms or the downsampled terms.
|
||||
for i in range(0, 10, 2):
|
||||
start_session = feb2014[i]
|
||||
self.check_extra_row_calculations(
|
||||
all_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i,
|
||||
expected_extra_rows=i,
|
||||
)
|
||||
|
||||
# Simulate requesting computation where the unaltered lookback would
|
||||
# land on the second date in feb 2014. We should request one more
|
||||
# extra row in the downsampled terms to push us back to the first date
|
||||
# in feb 2014.
|
||||
for i in range(0, 10, 2):
|
||||
start_session = feb2014[i + 1]
|
||||
self.check_extra_row_calculations(
|
||||
downsampled_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i,
|
||||
expected_extra_rows=i + 1,
|
||||
)
|
||||
self.check_extra_row_calculations(
|
||||
base_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i,
|
||||
expected_extra_rows=i,
|
||||
)
|
||||
|
||||
# Simulate requesting computation where the unaltered lookback would
|
||||
# land on the last date of jan 2014. The downsampled terms should
|
||||
# request enough extra rows to push us back to the start of jan 2014.
|
||||
for i in range(0, 10, 2):
|
||||
start_session = feb2014[i]
|
||||
self.check_extra_row_calculations(
|
||||
downsampled_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i + 1,
|
||||
expected_extra_rows=i + len(jan2014),
|
||||
)
|
||||
self.check_extra_row_calculations(
|
||||
base_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i + 1,
|
||||
expected_extra_rows=i + 1,
|
||||
)
|
||||
|
||||
# Simulate requesting computation where the unaltered lookback would
|
||||
# land on the last date of dec 2013. The downsampled terms should
|
||||
# request enough extra rows to push us back to the first known date,
|
||||
# which is in the middle of december 2013.
|
||||
for i in range(0, 10, 2):
|
||||
start_session = jan2014[i]
|
||||
self.check_extra_row_calculations(
|
||||
downsampled_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i + 1,
|
||||
expected_extra_rows=i + len(dec2013),
|
||||
)
|
||||
self.check_extra_row_calculations(
|
||||
base_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i + 1,
|
||||
expected_extra_rows=i + 1,
|
||||
)
|
||||
|
||||
@parameter_space(
|
||||
calendar_name=TRADING_CALENDAR_STRS,
|
||||
base_terms=[
|
||||
(factor1, factor11, factor91),
|
||||
(filter1, filter11, filter91),
|
||||
(classifier1, classifier11, classifier91),
|
||||
],
|
||||
__fail_fast=True
|
||||
)
|
||||
def test_weekly(self, calendar_name, base_terms):
|
||||
downsampled_terms = tuple(
|
||||
t.downsample('week_start') for t in base_terms
|
||||
)
|
||||
all_terms = base_terms + downsampled_terms
|
||||
|
||||
# December 2013
|
||||
# Mo Tu We Th Fr Sa Su
|
||||
# 1
|
||||
# 2 3 4 5 6 7 8
|
||||
# 9 10 11 12 13 14 15
|
||||
# 16 17 18 19 20 21 22
|
||||
# 23 24 25 26 27 28 29
|
||||
# 30 31
|
||||
|
||||
# January 2014
|
||||
# Mo Tu We Th Fr Sa Su
|
||||
# 1 2 3 4 5
|
||||
# 6 7 8 9 10 11 12
|
||||
# 13 14 15 16 17 18 19
|
||||
# 20 21 22 23 24 25 26
|
||||
# 27 28 29 30 31
|
||||
|
||||
# This region intersects with the last full week of 2013, the week
|
||||
# shared by 2013 and 2014, and the first full week of 2014.
|
||||
tmp = self.trading_sessions[calendar_name]
|
||||
all_sessions = tmp[tmp.slice_indexer('2013-12-27', '2014-01-12')]
|
||||
end_session = all_sessions[-1]
|
||||
|
||||
week0 = all_sessions[
|
||||
all_sessions.slice_indexer('2013-12-27', '2013-12-29')
|
||||
]
|
||||
week1 = all_sessions[
|
||||
all_sessions.slice_indexer('2013-12-30', '2014-01-05')
|
||||
]
|
||||
week2 = all_sessions[
|
||||
all_sessions.slice_indexer('2014-01-06', '2014-01-12')
|
||||
]
|
||||
|
||||
# Simulate requesting computation where the unaltered lookback would
|
||||
# land exactly on the first date in week 2. We shouldn't request any
|
||||
# additional rows for the regular terms or the downsampled terms.
|
||||
for i in range(3):
|
||||
start_session = week2[i]
|
||||
self.check_extra_row_calculations(
|
||||
all_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i,
|
||||
expected_extra_rows=i,
|
||||
)
|
||||
|
||||
# Simulate requesting computation where the unaltered lookback would
|
||||
# land exactly on the second date in week 2. The downsampled terms
|
||||
# should request one more extra row.
|
||||
for i in range(3):
|
||||
start_session = week2[i + 1]
|
||||
self.check_extra_row_calculations(
|
||||
downsampled_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i,
|
||||
expected_extra_rows=i + 1,
|
||||
)
|
||||
self.check_extra_row_calculations(
|
||||
base_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i,
|
||||
expected_extra_rows=i,
|
||||
)
|
||||
|
||||
# Simulate requesting computation where the unaltered lookback would
|
||||
# land exactly on the last date in week 1. The downsampled terms
|
||||
# should request enough extra rows to push us back to the first date of
|
||||
# week 1.
|
||||
for i in range(3):
|
||||
start_session = week2[i]
|
||||
self.check_extra_row_calculations(
|
||||
downsampled_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i + 1,
|
||||
expected_extra_rows=i + len(week1),
|
||||
)
|
||||
self.check_extra_row_calculations(
|
||||
base_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i + 1,
|
||||
expected_extra_rows=i + 1,
|
||||
)
|
||||
|
||||
# Simulate requesting computation where the unaltered lookback would
|
||||
# land exactly on the last date in week0. The downsampled terms
|
||||
# should request enough extra rows to push us back to the first known
|
||||
# date, which is 2013-12-27 (the start of the sliced session range).
|
||||
for i in range(3):
|
||||
start_session = week1[i]
|
||||
self.check_extra_row_calculations(
|
||||
downsampled_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i + 1,
|
||||
expected_extra_rows=i + len(week0),
|
||||
)
|
||||
self.check_extra_row_calculations(
|
||||
base_terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows=i + 1,
|
||||
expected_extra_rows=i + 1,
|
||||
)
|
||||
|
||||
def check_extra_row_calculations(self,
|
||||
terms,
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows,
|
||||
expected_extra_rows):
|
||||
"""
|
||||
Check that each term in ``terms`` computes an expected number of extra
|
||||
rows for the given parameters.
|
||||
"""
|
||||
for term in terms:
|
||||
result = term.compute_extra_rows(
|
||||
all_sessions,
|
||||
start_session,
|
||||
end_session,
|
||||
min_extra_rows,
|
||||
)
|
||||
self.assertEqual(
|
||||
result,
|
||||
expected_extra_rows,
|
||||
"Expected {} extra_rows from {}, but got {}.".format(
|
||||
expected_extra_rows,
|
||||
term,
|
||||
result,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class DownsampledPipelineTestCase(WithSeededRandomPipelineEngine,
|
||||
ZiplineTestCase):
|
||||
|
||||
# Extend into the last few days of 2013 to test year/quarter boundaries.
|
||||
START_DATE = pd.Timestamp('2013-12-15', tz='UTC')
|
||||
|
||||
# Extend into the first few days of 2015 to test year/quarter boundaries.
|
||||
END_DATE = pd.Timestamp('2015-01-06', tz='UTC')
|
||||
|
||||
ASSET_FINDER_EQUITY_SIDS = tuple(range(10))
|
||||
|
||||
def check_downsampled_term(self, term):
|
||||
|
||||
# June 2014
|
||||
# Mo Tu We Th Fr Sa Su
|
||||
# 1
|
||||
# 2 3 4 5 6 7 8
|
||||
# 9 10 11 12 13 14 15
|
||||
# 16 17 18 19 20 21 22
|
||||
# 23 24 25 26 27 28 29
|
||||
# 30
|
||||
all_sessions = self.nyse_sessions
|
||||
compute_dates = all_sessions[
|
||||
all_sessions.slice_indexer('2014-06-05', '2015-01-06')
|
||||
]
|
||||
start_date, end_date = compute_dates[[0, -1]]
|
||||
|
||||
pipe = Pipeline({
|
||||
'year': term.downsample(frequency='year_start'),
|
||||
'quarter': term.downsample(frequency='quarter_start'),
|
||||
'month': term.downsample(frequency='month_start'),
|
||||
'week': term.downsample(frequency='week_start'),
|
||||
})
|
||||
|
||||
# Raw values for term, computed each day from 2014 to the end of the
|
||||
# target period.
|
||||
raw_term_results = self.run_pipeline(
|
||||
Pipeline({'term': term}),
|
||||
start_date=pd.Timestamp('2014-01-02', tz='UTC'),
|
||||
end_date=pd.Timestamp('2015-01-06', tz='UTC'),
|
||||
)['term'].unstack()
|
||||
|
||||
expected_results = {
|
||||
'year': (raw_term_results
|
||||
.groupby(pd.TimeGrouper('AS'))
|
||||
.first()
|
||||
.reindex(compute_dates, method='ffill')),
|
||||
'quarter': (raw_term_results
|
||||
.groupby(pd.TimeGrouper('QS'))
|
||||
.first()
|
||||
.reindex(compute_dates, method='ffill')),
|
||||
'month': (raw_term_results
|
||||
.groupby(pd.TimeGrouper('MS'))
|
||||
.first()
|
||||
.reindex(compute_dates, method='ffill')),
|
||||
'week': (raw_term_results
|
||||
.groupby(pd.TimeGrouper('W', label='left'))
|
||||
.first()
|
||||
.reindex(compute_dates, method='ffill')),
|
||||
}
|
||||
|
||||
results = self.run_pipeline(pipe, start_date, end_date)
|
||||
|
||||
for frequency in expected_results:
|
||||
result = results[frequency].unstack()
|
||||
expected = expected_results[frequency]
|
||||
assert_frame_equal(result, expected)
|
||||
|
||||
def test_downsample_windowed_factor(self):
|
||||
self.check_downsampled_term(
|
||||
SimpleMovingAverage(
|
||||
inputs=[TestingDataSet.float_col],
|
||||
window_length=5,
|
||||
)
|
||||
)
|
||||
|
||||
def test_downsample_non_windowed_factor(self):
|
||||
sma = SimpleMovingAverage(
|
||||
inputs=[TestingDataSet.float_col],
|
||||
window_length=5,
|
||||
)
|
||||
|
||||
self.check_downsampled_term(((sma + sma) / 2).rank())
|
||||
|
||||
def test_downsample_windowed_filter(self):
|
||||
sma = SimpleMovingAverage(
|
||||
inputs=[TestingDataSet.float_col],
|
||||
window_length=5,
|
||||
)
|
||||
self.check_downsampled_term(All(inputs=[sma.top(4)], window_length=5))
|
||||
|
||||
def test_downsample_nonwindowed_filter(self):
|
||||
sma = SimpleMovingAverage(
|
||||
inputs=[TestingDataSet.float_col],
|
||||
window_length=5,
|
||||
)
|
||||
self.check_downsampled_term(sma > 5)
|
||||
|
||||
def test_downsample_windowed_classifier(self):
|
||||
|
||||
class IntSumClassifier(CustomClassifier):
|
||||
inputs = [TestingDataSet.float_col]
|
||||
window_length = 8
|
||||
dtype = int64_dtype
|
||||
missing_value = -1
|
||||
|
||||
def compute(self, today, assets, out, floats):
|
||||
out[:] = floats.sum(axis=0).astype(int) % 4
|
||||
|
||||
self.check_downsampled_term(IntSumClassifier())
|
||||
|
||||
def test_downsample_nonwindowed_classifier(self):
|
||||
sma = SimpleMovingAverage(
|
||||
inputs=[TestingDataSet.float_col],
|
||||
window_length=5,
|
||||
)
|
||||
self.check_downsampled_term(sma.quantiles(5))
|
||||
|
||||
def test_errors_on_bad_downsample_frequency(self):
|
||||
|
||||
f = NDaysAgoFactor(window_length=3)
|
||||
with self.assertRaises(ValueError) as e:
|
||||
f.downsample('bad')
|
||||
|
||||
expected = (
|
||||
"{}() expected a value in "
|
||||
"('month_start', 'quarter_start', 'week_start', 'year_start') "
|
||||
"for argument 'frequency', but got 'bad' instead."
|
||||
).format(_qualified_name(f.downsample))
|
||||
self.assertEqual(str(e.exception), expected)
|
||||
@@ -40,6 +40,7 @@ from six import iteritems, itervalues
|
||||
from toolz import merge
|
||||
|
||||
from zipline.assets.synthetic import make_rotating_equity_info
|
||||
from zipline.errors import NoFurtherDataError
|
||||
from zipline.lib.adjustment import MULTIPLY
|
||||
from zipline.lib.labelarray import LabelArray
|
||||
from zipline.pipeline import CustomFactor, Pipeline
|
||||
@@ -65,6 +66,7 @@ from zipline.pipeline.loaders.synthetic import (
|
||||
expected_bar_values_2d,
|
||||
)
|
||||
from zipline.pipeline.sentinels import NotSpecified
|
||||
from zipline.pipeline.term import InputDates
|
||||
from zipline.testing import (
|
||||
AssetID,
|
||||
AssetIDPlusDay,
|
||||
@@ -81,7 +83,7 @@ from zipline.testing.fixtures import (
|
||||
ZiplineTestCase,
|
||||
)
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.numpy_utils import bool_dtype
|
||||
from zipline.utils.numpy_utils import bool_dtype, datetime64ns_dtype
|
||||
|
||||
|
||||
class RollingSumDifference(CustomFactor):
|
||||
@@ -206,6 +208,53 @@ class ConstantInputTestCase(WithTradingEnvironment, ZiplineTestCase):
|
||||
with self.assertRaisesRegexp(ValueError, msg):
|
||||
engine.run_pipeline(p, self.dates[2], self.dates[1])
|
||||
|
||||
def test_fail_usefully_on_insufficient_data(self):
|
||||
loader = self.loader
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column: loader, self.dates, self.asset_finder,
|
||||
)
|
||||
|
||||
class SomeFactor(CustomFactor):
|
||||
inputs = [USEquityPricing.close]
|
||||
window_length = 10
|
||||
|
||||
def compute(self, today, assets, out, closes):
|
||||
pass
|
||||
|
||||
p = Pipeline(columns={'t': SomeFactor()})
|
||||
|
||||
# self.dates[9] is the earliest date we should be able to compute.
|
||||
engine.run_pipeline(p, self.dates[9], self.dates[9])
|
||||
|
||||
# We shouldn't be able to compute dates[8], since we only know about 8
|
||||
# prior dates, and we need a window length of 10.
|
||||
with self.assertRaises(NoFurtherDataError):
|
||||
engine.run_pipeline(p, self.dates[8], self.dates[8])
|
||||
|
||||
def test_input_dates_provided_by_default(self):
|
||||
loader = self.loader
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column: loader, self.dates, self.asset_finder,
|
||||
)
|
||||
|
||||
class TestFactor(CustomFactor):
|
||||
inputs = [InputDates(), USEquityPricing.close]
|
||||
window_length = 10
|
||||
dtype = datetime64ns_dtype
|
||||
|
||||
def compute(self, today, assets, out, dates, closes):
|
||||
first, last = dates[[0, -1], 0]
|
||||
assert last == today.asm8
|
||||
assert len(dates) == len(closes) == self.window_length
|
||||
out[:] = first
|
||||
|
||||
p = Pipeline(columns={'t': TestFactor()})
|
||||
results = engine.run_pipeline(p, self.dates[9], self.dates[10])
|
||||
|
||||
# All results are the same, so just grab one column.
|
||||
column = results.unstack().iloc[:, 0].values
|
||||
check_arrays(column, self.dates[:2].values)
|
||||
|
||||
def test_same_day_pipeline(self):
|
||||
loader = self.loader
|
||||
engine = SimplePipelineEngine(
|
||||
|
||||
@@ -26,7 +26,7 @@ from zipline.errors import UnknownRankMethod
|
||||
from zipline.lib.labelarray import LabelArray
|
||||
from zipline.lib.rank import masked_rankdata_2d
|
||||
from zipline.lib.normalize import naive_grouped_rowwise_apply as grouped_apply
|
||||
from zipline.pipeline import Classifier, Factor, Filter, TermGraph
|
||||
from zipline.pipeline import Classifier, Factor, Filter
|
||||
from zipline.pipeline.factors import (
|
||||
Returns,
|
||||
RSI,
|
||||
@@ -37,7 +37,6 @@ from zipline.testing import (
|
||||
parameter_space,
|
||||
permute_rows,
|
||||
)
|
||||
from zipline.utils.functional import dzip_exact
|
||||
from zipline.utils.numpy_utils import (
|
||||
categorical_dtype,
|
||||
datetime64ns_dtype,
|
||||
@@ -123,20 +122,18 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
data = arange(25).reshape(5, 5)
|
||||
data[eye(5, dtype=bool)] = custom_missing_value
|
||||
|
||||
graph = TermGraph(
|
||||
self.check_terms(
|
||||
{
|
||||
'isnull': factor.isnull(),
|
||||
'notnull': factor.notnull(),
|
||||
}
|
||||
)
|
||||
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
},
|
||||
{
|
||||
'isnull': eye(5, dtype=bool),
|
||||
'notnull': ~eye(5, dtype=bool),
|
||||
},
|
||||
initial_workspace={factor: data},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
check_arrays(results['isnull'], eye(5, dtype=bool))
|
||||
check_arrays(results['notnull'], ~eye(5, dtype=bool))
|
||||
|
||||
def test_isnull_datetime_dtype(self):
|
||||
class DatetimeFactor(Factor):
|
||||
@@ -149,20 +146,18 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
data = arange(25).reshape(5, 5).astype('datetime64[ns]')
|
||||
data[eye(5, dtype=bool)] = NaTns
|
||||
|
||||
graph = TermGraph(
|
||||
self.check_terms(
|
||||
{
|
||||
'isnull': factor.isnull(),
|
||||
'notnull': factor.notnull(),
|
||||
}
|
||||
)
|
||||
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
},
|
||||
{
|
||||
'isnull': eye(5, dtype=bool),
|
||||
'notnull': ~eye(5, dtype=bool),
|
||||
},
|
||||
initial_workspace={factor: data},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
check_arrays(results['isnull'], eye(5, dtype=bool))
|
||||
check_arrays(results['notnull'], ~eye(5, dtype=bool))
|
||||
|
||||
@for_each_factor_dtype
|
||||
def test_rank_ascending(self, name, factor_dtype):
|
||||
@@ -206,14 +201,12 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
}
|
||||
|
||||
def check(terms):
|
||||
graph = TermGraph(terms)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected={name: expected_ranks[name] for name in terms},
|
||||
initial_workspace={f: data},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
for method in terms:
|
||||
check_arrays(results[method], expected_ranks[method])
|
||||
|
||||
check({meth: f.rank(method=meth) for meth in expected_ranks})
|
||||
check({
|
||||
@@ -265,14 +258,12 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
}
|
||||
|
||||
def check(terms):
|
||||
graph = TermGraph(terms)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected={name: expected_ranks[name] for name in terms},
|
||||
initial_workspace={f: data},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
for method in terms:
|
||||
check_arrays(results[method], expected_ranks[method])
|
||||
|
||||
check({
|
||||
meth: f.rank(method=meth, ascending=False)
|
||||
@@ -294,14 +285,12 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
mask_data = ~eye(5, dtype=bool)
|
||||
initial_workspace = {f: data, Mask(): mask_data}
|
||||
|
||||
graph = TermGraph(
|
||||
{
|
||||
"ascending_nomask": f.rank(ascending=True),
|
||||
"ascending_mask": f.rank(ascending=True, mask=Mask()),
|
||||
"descending_nomask": f.rank(ascending=False),
|
||||
"descending_mask": f.rank(ascending=False, mask=Mask()),
|
||||
}
|
||||
)
|
||||
terms = {
|
||||
"ascending_nomask": f.rank(ascending=True),
|
||||
"ascending_mask": f.rank(ascending=True, mask=Mask()),
|
||||
"descending_nomask": f.rank(ascending=False),
|
||||
"descending_mask": f.rank(ascending=False, mask=Mask()),
|
||||
}
|
||||
|
||||
expected = {
|
||||
"ascending_nomask": array([[1., 3., 4., 5., 2.],
|
||||
@@ -328,13 +317,12 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
[4., 3., 2., 1., nan]]),
|
||||
}
|
||||
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected,
|
||||
initial_workspace,
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
for method in results:
|
||||
check_arrays(expected[method], results[method])
|
||||
|
||||
@for_each_factor_dtype
|
||||
def test_grouped_rank_ascending(self, name, factor_dtype=float64_dtype):
|
||||
@@ -363,7 +351,7 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
missing_value=None,
|
||||
)
|
||||
|
||||
expected_grouped_ranks = {
|
||||
expected_ranks = {
|
||||
'ordinal': array(
|
||||
[[1., 1., 3., 2., 2.],
|
||||
[1., 2., 3., 1., 2.],
|
||||
@@ -402,9 +390,9 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
}
|
||||
|
||||
def check(terms):
|
||||
graph = TermGraph(terms)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected={name: expected_ranks[name] for name in terms},
|
||||
initial_workspace={
|
||||
f: data,
|
||||
c: classifier_data,
|
||||
@@ -413,25 +401,22 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
|
||||
for method in terms:
|
||||
check_arrays(results[method], expected_grouped_ranks[method])
|
||||
|
||||
# Not specifying the value of ascending param should default to True
|
||||
check({
|
||||
meth: f.rank(method=meth, groupby=c)
|
||||
for meth in expected_grouped_ranks
|
||||
for meth in expected_ranks
|
||||
})
|
||||
check({
|
||||
meth: f.rank(method=meth, groupby=str_c)
|
||||
for meth in expected_grouped_ranks
|
||||
for meth in expected_ranks
|
||||
})
|
||||
check({
|
||||
meth: f.rank(method=meth, groupby=c, ascending=True)
|
||||
for meth in expected_grouped_ranks
|
||||
for meth in expected_ranks
|
||||
})
|
||||
check({
|
||||
meth: f.rank(method=meth, groupby=str_c, ascending=True)
|
||||
for meth in expected_grouped_ranks
|
||||
for meth in expected_ranks
|
||||
})
|
||||
|
||||
# Not passing a method should default to ordinal
|
||||
@@ -468,7 +453,7 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
missing_value=None,
|
||||
)
|
||||
|
||||
expected_grouped_ranks = {
|
||||
expected_ranks = {
|
||||
'ordinal': array(
|
||||
[[2., 2., 1., 1., 3.],
|
||||
[2., 1., 1., 2., 3.],
|
||||
@@ -507,9 +492,9 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
}
|
||||
|
||||
def check(terms):
|
||||
graph = TermGraph(terms)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected={name: expected_ranks[name] for name in terms},
|
||||
initial_workspace={
|
||||
f: data,
|
||||
c: classifier_data,
|
||||
@@ -518,16 +503,13 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
|
||||
for method in terms:
|
||||
check_arrays(results[method], expected_grouped_ranks[method])
|
||||
|
||||
check({
|
||||
meth: f.rank(method=meth, groupby=c, ascending=False)
|
||||
for meth in expected_grouped_ranks
|
||||
for meth in expected_ranks
|
||||
})
|
||||
check({
|
||||
meth: f.rank(method=meth, groupby=str_c, ascending=False)
|
||||
for meth in expected_grouped_ranks
|
||||
for meth in expected_ranks
|
||||
})
|
||||
|
||||
# Not passing a method should default to ordinal
|
||||
@@ -707,9 +689,9 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
expected['grouped_str'] = expected['grouped']
|
||||
expected['grouped_masked_str'] = expected['grouped_masked']
|
||||
|
||||
graph = TermGraph(terms)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected,
|
||||
initial_workspace={
|
||||
f: factor_data,
|
||||
c: classifier_data,
|
||||
@@ -717,20 +699,13 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
m: filter_data,
|
||||
},
|
||||
mask=self.build_mask(self.ones_mask(shape=factor_data.shape)),
|
||||
# The hand-computed values aren't very precise (in particular,
|
||||
# we truncate repeating decimals at 3 places). This is just
|
||||
# asserting that the example isn't misleading by being totally
|
||||
# wrong.
|
||||
check=partial(check_allclose, atol=0.001),
|
||||
)
|
||||
|
||||
for key, (res, exp) in dzip_exact(results, expected).items():
|
||||
check_allclose(
|
||||
res,
|
||||
exp,
|
||||
# The hand-computed values aren't very precise (in particular,
|
||||
# we truncate repeating decimals at 3 places). This is just
|
||||
# asserting that the example isn't misleading by being totally
|
||||
# wrong.
|
||||
atol=0.001,
|
||||
err_msg="Mismatch for %r" % key
|
||||
)
|
||||
|
||||
@parameter_space(
|
||||
seed_value=range(1, 2),
|
||||
normalizer_name_and_func=[
|
||||
|
||||
+156
-179
@@ -12,7 +12,6 @@ from numpy import (
|
||||
array,
|
||||
eye,
|
||||
float64,
|
||||
full_like,
|
||||
full,
|
||||
inf,
|
||||
isfinite,
|
||||
@@ -27,11 +26,11 @@ from numpy import (
|
||||
from numpy.random import randn, seed as random_seed
|
||||
|
||||
from zipline.errors import BadPercentileBounds
|
||||
from zipline.pipeline import Filter, Factor, TermGraph
|
||||
from zipline.pipeline import Filter, Factor
|
||||
from zipline.pipeline.classifiers import Classifier
|
||||
from zipline.pipeline.factors import CustomFactor
|
||||
from zipline.pipeline.filters import All, Any, AtLeastN
|
||||
from zipline.testing import check_arrays, parameter_space, permute_rows
|
||||
from zipline.testing import parameter_space, permute_rows
|
||||
from zipline.utils.numpy_utils import float64_dtype, int64_dtype
|
||||
from .base import BasePipelineTestCase, with_default_shape
|
||||
|
||||
@@ -127,7 +126,6 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
nan_data[:, 0] = nan
|
||||
|
||||
mask = Mask()
|
||||
workspace = {self.f: data, mask: mask_data}
|
||||
|
||||
methods = ['top', 'bottom']
|
||||
counts = 2, 3, 10
|
||||
@@ -136,18 +134,6 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
def termname(method, count, masked):
|
||||
return '_'.join([method, str(count), 'mask' if masked else ''])
|
||||
|
||||
# Add a term for each permutation of top/bottom, count, and
|
||||
# mask/no_mask.
|
||||
terms = {}
|
||||
for method, count, masked in term_combos:
|
||||
kwargs = {'N': count}
|
||||
if masked:
|
||||
kwargs['mask'] = mask
|
||||
term = getattr(self.f, method)(**kwargs)
|
||||
terms[termname(method, count, masked)] = term
|
||||
|
||||
results = self.run_graph(TermGraph(terms), initial_workspace=workspace)
|
||||
|
||||
def expected_result(method, count, masked):
|
||||
# Ranking with a mask is equivalent to ranking with nans applied on
|
||||
# the masked values.
|
||||
@@ -158,72 +144,55 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
elif method == 'bottom':
|
||||
return rowwise_rank(to_rank) < count
|
||||
|
||||
# Add a term for each permutation of top/bottom, count, and
|
||||
# mask/no_mask.
|
||||
terms = {}
|
||||
expected = {}
|
||||
for method, count, masked in term_combos:
|
||||
result = results[termname(method, count, masked)]
|
||||
kwargs = {'N': count}
|
||||
if masked:
|
||||
kwargs['mask'] = mask
|
||||
term = getattr(self.f, method)(**kwargs)
|
||||
name = termname(method, count, masked)
|
||||
terms[name] = term
|
||||
expected[name] = expected_result(method, count, masked)
|
||||
|
||||
# Check that `min(c, num_assets)` assets passed each day.
|
||||
passed_per_day = result.sum(axis=1)
|
||||
check_arrays(
|
||||
passed_per_day,
|
||||
full_like(passed_per_day, min(count, data.shape[1])),
|
||||
)
|
||||
|
||||
expected = expected_result(method, count, masked)
|
||||
check_arrays(result, expected)
|
||||
|
||||
def test_bottom(self):
|
||||
counts = 2, 3, 10
|
||||
data = self.randn_data(seed=5) # Arbitrary seed choice.
|
||||
results = self.run_graph(
|
||||
TermGraph(
|
||||
{'bottom_' + str(c): self.f.bottom(c) for c in counts}
|
||||
),
|
||||
initial_workspace={self.f: data},
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected,
|
||||
initial_workspace={self.f: data, mask: mask_data},
|
||||
mask=self.build_mask(self.ones_mask()),
|
||||
)
|
||||
for c in counts:
|
||||
result = results['bottom_' + str(c)]
|
||||
|
||||
# Check that `min(c, num_assets)` assets passed each day.
|
||||
passed_per_day = result.sum(axis=1)
|
||||
check_arrays(
|
||||
passed_per_day,
|
||||
full_like(passed_per_day, min(c, data.shape[1])),
|
||||
)
|
||||
|
||||
# Check that the bottom `c` assets passed.
|
||||
expected = rowwise_rank(data) < c
|
||||
check_arrays(result, expected)
|
||||
|
||||
def test_percentile_between(self):
|
||||
|
||||
quintiles = range(5)
|
||||
filter_names = ['pct_' + str(q) for q in quintiles]
|
||||
iter_quintiles = zip(filter_names, quintiles)
|
||||
|
||||
graph = TermGraph(
|
||||
{
|
||||
name: self.f.percentile_between(q * 20.0, (q + 1) * 20.0)
|
||||
for name, q in zip(filter_names, quintiles)
|
||||
}
|
||||
)
|
||||
iter_quintiles = list(zip(filter_names, quintiles))
|
||||
terms = {
|
||||
name: self.f.percentile_between(q * 20.0, (q + 1) * 20.0)
|
||||
for name, q in iter_quintiles
|
||||
}
|
||||
|
||||
# Test with 5 columns and no NaNs.
|
||||
eye5 = eye(5, dtype=float64)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
initial_workspace={self.f: eye5},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
expected = {}
|
||||
for name, quintile in iter_quintiles:
|
||||
result = results[name]
|
||||
if quintile < 4:
|
||||
# There are four 0s and one 1 in each row, so the first 4
|
||||
# quintiles should be all the locations with zeros in the input
|
||||
# array.
|
||||
check_arrays(result, ~eye5.astype(bool))
|
||||
expected[name] = ~eye5.astype(bool)
|
||||
else:
|
||||
# The top quintile should match the sole 1 in each row.
|
||||
check_arrays(result, eye5.astype(bool))
|
||||
expected[name] = eye5.astype(bool)
|
||||
|
||||
self.check_terms(
|
||||
terms=terms,
|
||||
expected=expected,
|
||||
initial_workspace={self.f: eye5},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
|
||||
# Test with 6 columns, no NaNs, and one masked entry per day.
|
||||
eye6 = eye(6, dtype=float64)
|
||||
@@ -233,41 +202,44 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
[1, 1, 0, 1, 1, 1],
|
||||
[1, 1, 1, 0, 1, 1],
|
||||
[1, 1, 1, 1, 0, 1]], dtype=bool)
|
||||
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
initial_workspace={self.f: eye6},
|
||||
mask=self.build_mask(mask)
|
||||
)
|
||||
expected = {}
|
||||
for name, quintile in iter_quintiles:
|
||||
result = results[name]
|
||||
if quintile < 4:
|
||||
# Should keep all values that were 0 in the base data and were
|
||||
# 1 in the mask.
|
||||
check_arrays(result, mask & (~eye6.astype(bool))),
|
||||
expected[name] = mask & ~eye6.astype(bool)
|
||||
else:
|
||||
# Should keep all the 1s in the base data.
|
||||
check_arrays(result, eye6.astype(bool))
|
||||
# The top quintile should match the sole 1 in each row.
|
||||
expected[name] = eye6.astype(bool)
|
||||
|
||||
self.check_terms(
|
||||
terms=terms,
|
||||
expected=expected,
|
||||
initial_workspace={self.f: eye6},
|
||||
mask=self.build_mask(mask),
|
||||
)
|
||||
|
||||
# Test with 6 columns, no mask, and one NaN per day. Should have the
|
||||
# same outcome as if we had masked the NaNs.
|
||||
# In particular, the NaNs should never pass any filters.
|
||||
eye6_withnans = eye6.copy()
|
||||
putmask(eye6_withnans, ~mask, nan)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
initial_workspace={self.f: eye6},
|
||||
mask=self.build_mask(mask)
|
||||
)
|
||||
expected = {}
|
||||
for name, quintile in iter_quintiles:
|
||||
result = results[name]
|
||||
if quintile < 4:
|
||||
# Should keep all values that were 0 in the base data and were
|
||||
# 1 in the mask.
|
||||
check_arrays(result, mask & (~eye6.astype(bool))),
|
||||
expected[name] = mask & (~eye6.astype(bool))
|
||||
else:
|
||||
# Should keep all the 1s in the base data.
|
||||
check_arrays(result, eye6.astype(bool))
|
||||
expected[name] = eye6.astype(bool)
|
||||
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected,
|
||||
initial_workspace={self.f: eye6},
|
||||
mask=self.build_mask(mask),
|
||||
)
|
||||
|
||||
def test_percentile_nasty_partitions(self):
|
||||
# Test percentile with nasty partitions: divide up 5 assets into
|
||||
@@ -281,27 +253,26 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
quartiles = range(4)
|
||||
filter_names = ['pct_' + str(q) for q in quartiles]
|
||||
|
||||
graph = TermGraph(
|
||||
{
|
||||
name: self.f.percentile_between(q * 25.0, (q + 1) * 25.0)
|
||||
for name, q in zip(filter_names, quartiles)
|
||||
}
|
||||
)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
initial_workspace={self.f: data},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
terms = {
|
||||
name: self.f.percentile_between(q * 25.0, (q + 1) * 25.0)
|
||||
for name, q in zip(filter_names, quartiles)
|
||||
}
|
||||
|
||||
expected = {}
|
||||
for name, quartile in zip(filter_names, quartiles):
|
||||
result = results[name]
|
||||
lower = quartile * 25.0
|
||||
upper = (quartile + 1) * 25.0
|
||||
expected = and_(
|
||||
expected[name] = and_(
|
||||
nanpercentile(data, lower, axis=1, keepdims=True) <= data,
|
||||
data <= nanpercentile(data, upper, axis=1, keepdims=True),
|
||||
)
|
||||
check_arrays(result, expected)
|
||||
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected,
|
||||
initial_workspace={self.f: data},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
|
||||
def test_percentile_after_mask(self):
|
||||
f_input = eye(5)
|
||||
@@ -312,77 +283,79 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
without_mask = self.g.percentile_between(80, 100)
|
||||
with_mask = self.g.percentile_between(80, 100, mask=custom_mask)
|
||||
|
||||
graph = TermGraph(
|
||||
{
|
||||
'custom_mask': custom_mask,
|
||||
'without': without_mask,
|
||||
'with': with_mask,
|
||||
}
|
||||
)
|
||||
terms = {
|
||||
'mask': custom_mask,
|
||||
'without_mask': without_mask,
|
||||
'with_mask': with_mask,
|
||||
}
|
||||
expected = {
|
||||
# Mask that accepts everything except the diagonal.
|
||||
'mask': ~eye(5, dtype=bool),
|
||||
# Second should pass the largest value each day. Each row is
|
||||
# strictly increasing, so we always select the last value.
|
||||
'without_mask': array(
|
||||
[[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1]],
|
||||
dtype=bool,
|
||||
),
|
||||
# With a mask, we should remove the diagonal as an option before
|
||||
# computing percentiles. On the last day, we should get the
|
||||
# second-largest value, rather than the largest.
|
||||
'with_mask': array(
|
||||
[[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 1, 0]], # Different from without_mask!
|
||||
dtype=bool,
|
||||
),
|
||||
}
|
||||
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected,
|
||||
initial_workspace={self.f: f_input, self.g: g_input},
|
||||
mask=initial_mask,
|
||||
)
|
||||
|
||||
# First should pass everything but the diagonal.
|
||||
check_arrays(results['custom_mask'], ~eye(5, dtype=bool))
|
||||
|
||||
# Second should pass the largest value each day. Each row is strictly
|
||||
# increasing, so we always select the last value.
|
||||
expected_without = array(
|
||||
[[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1]],
|
||||
dtype=bool,
|
||||
)
|
||||
check_arrays(results['without'], expected_without)
|
||||
|
||||
# When sequencing, we should remove the diagonal as an option before
|
||||
# computing percentiles. On the last day, we should get the
|
||||
# second-largest value, rather than the largest.
|
||||
expected_with = array(
|
||||
[[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 1, 0]], # Different from previous!
|
||||
dtype=bool,
|
||||
)
|
||||
check_arrays(results['with'], expected_with)
|
||||
|
||||
def test_isnan(self):
|
||||
data = self.randn_data(seed=10)
|
||||
diag = eye(*data.shape, dtype=bool)
|
||||
data[diag] = nan
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({
|
||||
self.check_terms(
|
||||
terms={
|
||||
'isnan': self.f.isnan(),
|
||||
'isnull': self.f.isnull(),
|
||||
}),
|
||||
},
|
||||
expected={
|
||||
'isnan': diag,
|
||||
'isnull': diag,
|
||||
},
|
||||
initial_workspace={self.f: data},
|
||||
mask=self.build_mask(self.ones_mask()),
|
||||
)
|
||||
check_arrays(results['isnan'], diag)
|
||||
check_arrays(results['isnull'], diag)
|
||||
|
||||
def test_notnan(self):
|
||||
data = self.randn_data(seed=10)
|
||||
diag = eye(*data.shape, dtype=bool)
|
||||
data[diag] = nan
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({
|
||||
self.check_terms(
|
||||
terms={
|
||||
'notnan': self.f.notnan(),
|
||||
'notnull': self.f.notnull(),
|
||||
}),
|
||||
},
|
||||
expected={
|
||||
'notnan': ~diag,
|
||||
'notnull': ~diag,
|
||||
},
|
||||
initial_workspace={self.f: data},
|
||||
mask=self.build_mask(self.ones_mask()),
|
||||
)
|
||||
check_arrays(results['notnan'], ~diag)
|
||||
check_arrays(results['notnull'], ~diag)
|
||||
|
||||
def test_isfinite(self):
|
||||
data = self.randn_data(seed=10)
|
||||
@@ -390,11 +363,12 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
data[:, 2] = inf
|
||||
data[:, 4] = -inf
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({'isfinite': self.f.isfinite()}),
|
||||
self.check_terms(
|
||||
terms={'isfinite': self.f.isfinite()},
|
||||
expected={'isfinite': isfinite(data)},
|
||||
initial_workspace={self.f: data},
|
||||
mask=self.build_mask(self.ones_mask()),
|
||||
)
|
||||
check_arrays(results['isfinite'], isfinite(data))
|
||||
|
||||
def test_all(self):
|
||||
|
||||
@@ -427,18 +401,19 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({
|
||||
self.check_terms(
|
||||
terms={
|
||||
'3': All(inputs=[Input()], window_length=3),
|
||||
'4': All(inputs=[Input()], window_length=4),
|
||||
}),
|
||||
},
|
||||
expected={
|
||||
'3': expected_3,
|
||||
'4': expected_4,
|
||||
},
|
||||
initial_workspace={Input(): data},
|
||||
mask=self.build_mask(ones(shape=data.shape)),
|
||||
)
|
||||
|
||||
check_arrays(results['3'], expected_3)
|
||||
check_arrays(results['4'], expected_4)
|
||||
|
||||
def test_any(self):
|
||||
|
||||
# FUN FACT: The inputs and outputs here are exactly the negation of
|
||||
@@ -486,18 +461,19 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({
|
||||
self.check_terms(
|
||||
terms={
|
||||
'3': Any(inputs=[Input()], window_length=3),
|
||||
'4': Any(inputs=[Input()], window_length=4),
|
||||
}),
|
||||
},
|
||||
expected={
|
||||
'3': expected_3,
|
||||
'4': expected_4,
|
||||
},
|
||||
initial_workspace={Input(): data},
|
||||
mask=self.build_mask(ones(shape=data.shape)),
|
||||
)
|
||||
|
||||
check_arrays(results['3'], expected_3)
|
||||
check_arrays(results['4'], expected_4)
|
||||
|
||||
def test_at_least_N(self):
|
||||
|
||||
# With a window_length of K, AtLeastN should return 1
|
||||
@@ -553,26 +529,27 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
window_length=4,
|
||||
N=4)
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({
|
||||
self.check_terms(
|
||||
terms={
|
||||
'AllButOne': all_but_one,
|
||||
'AllButTwo': all_but_two,
|
||||
'AnyEquiv': any_equiv,
|
||||
'AllEquiv': all_equiv,
|
||||
'Any': Any(inputs=[Input()], window_length=4),
|
||||
'All': All(inputs=[Input()], window_length=4)
|
||||
}),
|
||||
},
|
||||
expected={
|
||||
'Any': expected_1,
|
||||
'AnyEquiv': expected_1,
|
||||
'AllButTwo': expected_2,
|
||||
'AllButOne': expected_3,
|
||||
'All': expected_4,
|
||||
'AllEquiv': expected_4,
|
||||
},
|
||||
initial_workspace={Input(): data},
|
||||
mask=self.build_mask(ones(shape=data.shape)),
|
||||
)
|
||||
|
||||
check_arrays(results['Any'], expected_1)
|
||||
check_arrays(results['AnyEquiv'], expected_1)
|
||||
check_arrays(results['AllButTwo'], expected_2)
|
||||
check_arrays(results['AllButOne'], expected_3)
|
||||
check_arrays(results['All'], expected_4)
|
||||
check_arrays(results['AllEquiv'], expected_4)
|
||||
|
||||
@parameter_space(factor_len=[2, 3, 4])
|
||||
def test_window_safe(self, factor_len):
|
||||
# all true data set of (days, securities)
|
||||
@@ -591,19 +568,19 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
# sum for each column
|
||||
out[:] = np_sum(filter_, axis=0)
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({'windowsafe': TestFactor()}),
|
||||
initial_workspace={InputFilter(): data},
|
||||
)
|
||||
|
||||
# number of days in default_shape
|
||||
n = self.default_shape[0]
|
||||
|
||||
# shape of output array
|
||||
output_shape = ((n - factor_len + 1), self.default_shape[1])
|
||||
check_arrays(
|
||||
results['windowsafe'],
|
||||
full(output_shape, factor_len, dtype=float64)
|
||||
full(output_shape, factor_len, dtype=float64)
|
||||
|
||||
self.check_terms(
|
||||
terms={
|
||||
'windowsafe': TestFactor(),
|
||||
},
|
||||
expected={
|
||||
'windowsafe': full(output_shape, factor_len, dtype=float64),
|
||||
},
|
||||
initial_workspace={InputFilter(): data},
|
||||
mask=self.build_mask(self.ones_mask()),
|
||||
)
|
||||
|
||||
@parameter_space(
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
"""
|
||||
Tests for zipline.pipeline.Pipeline
|
||||
"""
|
||||
import inspect
|
||||
from unittest import TestCase
|
||||
|
||||
from mock import patch
|
||||
|
||||
from zipline.pipeline import Factor, Filter, Pipeline
|
||||
from zipline.pipeline.data import USEquityPricing
|
||||
from zipline.pipeline.graph import display_graph
|
||||
from zipline.utils.numpy_utils import float64_dtype
|
||||
|
||||
|
||||
@@ -137,3 +141,57 @@ class PipelineTestCase(TestCase):
|
||||
"expected a value of type bool or int for argument 'overwrite'",
|
||||
message,
|
||||
)
|
||||
|
||||
def test_show_graph(self):
|
||||
f = SomeFactor()
|
||||
p = Pipeline(columns={'f': SomeFactor()})
|
||||
|
||||
# The real display_graph call shells out to GraphViz, which isn't a
|
||||
# requirement, so patch it out for testing.
|
||||
|
||||
def mock_display_graph(g, format='svg', include_asset_exists=False):
|
||||
return (g, format, include_asset_exists)
|
||||
|
||||
self.assertEqual(
|
||||
inspect.getargspec(display_graph),
|
||||
inspect.getargspec(mock_display_graph),
|
||||
msg="Mock signature doesn't match signature for display_graph."
|
||||
)
|
||||
|
||||
patch_display_graph = patch(
|
||||
'zipline.pipeline.graph.display_graph',
|
||||
mock_display_graph,
|
||||
)
|
||||
|
||||
with patch_display_graph:
|
||||
graph, format, include_asset_exists = p.show_graph()
|
||||
self.assertIs(graph.outputs['f'], f)
|
||||
# '' is a sentinel used for screen if it's not supplied.
|
||||
self.assertEqual(sorted(graph.outputs.keys()), ['', 'f'])
|
||||
self.assertEqual(format, 'svg')
|
||||
self.assertEqual(include_asset_exists, False)
|
||||
|
||||
with patch_display_graph:
|
||||
graph, format, include_asset_exists = p.show_graph(format='png')
|
||||
self.assertIs(graph.outputs['f'], f)
|
||||
# '' is a sentinel used for screen if it's not supplied.
|
||||
self.assertEqual(sorted(graph.outputs.keys()), ['', 'f'])
|
||||
self.assertEqual(format, 'png')
|
||||
self.assertEqual(include_asset_exists, False)
|
||||
|
||||
with patch_display_graph:
|
||||
graph, format, include_asset_exists = p.show_graph(format='jpeg')
|
||||
self.assertIs(graph.outputs['f'], f)
|
||||
# '' is a sentinel used for screen if it's not supplied.
|
||||
self.assertEqual(sorted(graph.outputs.keys()), ['', 'f'])
|
||||
self.assertEqual(format, 'jpeg')
|
||||
self.assertEqual(include_asset_exists, False)
|
||||
|
||||
expected = (
|
||||
r".*\.show_graph\(\) expected a value in "
|
||||
r"\('svg', 'png', 'jpeg'\) for argument 'format', "
|
||||
r"but got 'fizzbuzz' instead."
|
||||
)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, expected):
|
||||
p.show_graph(format='fizzbuzz')
|
||||
|
||||
@@ -7,10 +7,7 @@ import pandas as pd
|
||||
import talib
|
||||
|
||||
from zipline.lib.adjusted_array import AdjustedArray
|
||||
from zipline.pipeline import TermGraph
|
||||
from zipline.pipeline.data import USEquityPricing
|
||||
from zipline.pipeline.engine import SimplePipelineEngine
|
||||
from zipline.pipeline.term import AssetExists
|
||||
from zipline.pipeline.factors import (
|
||||
BollingerBands,
|
||||
Aroon,
|
||||
@@ -20,61 +17,22 @@ from zipline.pipeline.factors import (
|
||||
RateOfChangePercentage,
|
||||
TrueRange,
|
||||
)
|
||||
from zipline.testing import ExplodingObject, parameter_space
|
||||
from zipline.testing.fixtures import WithAssetFinder, ZiplineTestCase
|
||||
from zipline.testing import parameter_space
|
||||
from zipline.testing.fixtures import ZiplineTestCase
|
||||
from zipline.testing.predicates import assert_equal
|
||||
|
||||
|
||||
class WithTechnicalFactor(WithAssetFinder):
|
||||
"""ZiplineTestCase fixture for testing technical factors.
|
||||
"""
|
||||
ASSET_FINDER_EQUITY_SIDS = tuple(range(5))
|
||||
START_DATE = pd.Timestamp('2014-01-01', tz='utc')
|
||||
|
||||
@classmethod
|
||||
def init_class_fixtures(cls):
|
||||
super(WithTechnicalFactor, cls).init_class_fixtures()
|
||||
cls.ndays = ndays = 24
|
||||
cls.nassets = nassets = len(cls.ASSET_FINDER_EQUITY_SIDS)
|
||||
cls.dates = dates = pd.date_range(cls.START_DATE, periods=ndays)
|
||||
cls.assets = pd.Index(cls.asset_finder.sids)
|
||||
cls.engine = SimplePipelineEngine(
|
||||
lambda column: ExplodingObject(),
|
||||
dates,
|
||||
cls.asset_finder,
|
||||
)
|
||||
cls.asset_exists = exists = np.full((ndays, nassets), True, dtype=bool)
|
||||
cls.asset_exists_masked = masked = exists.copy()
|
||||
masked[:, -1] = False
|
||||
|
||||
def run_graph(self, graph, initial_workspace, mask_sid):
|
||||
initial_workspace.setdefault(
|
||||
AssetExists(),
|
||||
self.asset_exists_masked if mask_sid else self.asset_exists,
|
||||
)
|
||||
return self.engine.compute_chunk(
|
||||
graph,
|
||||
self.dates,
|
||||
self.assets,
|
||||
initial_workspace,
|
||||
)
|
||||
from .base import BasePipelineTestCase
|
||||
|
||||
|
||||
class BollingerBandsTestCase(WithTechnicalFactor, ZiplineTestCase):
|
||||
@classmethod
|
||||
def init_class_fixtures(cls):
|
||||
super(BollingerBandsTestCase, cls).init_class_fixtures()
|
||||
cls._closes = closes = (
|
||||
np.arange(cls.ndays, dtype=float)[:, np.newaxis] +
|
||||
np.arange(cls.nassets, dtype=float) * 100
|
||||
)
|
||||
cls._closes_masked = masked = closes.copy()
|
||||
masked[:, -1] = np.nan
|
||||
class BollingerBandsTestCase(BasePipelineTestCase):
|
||||
|
||||
def closes(self, masked):
|
||||
return self._closes_masked if masked else self._closes
|
||||
def closes(self, mask_last_sid):
|
||||
data = self.arange_data(dtype=np.float64)
|
||||
if mask_last_sid:
|
||||
data[:, -1] = np.nan
|
||||
return data
|
||||
|
||||
def expected(self, window_length, k, closes):
|
||||
def expected_bbands(self, window_length, k, closes):
|
||||
"""Compute the expected data (without adjustments) for the given
|
||||
window, k, and closes array.
|
||||
|
||||
@@ -83,11 +41,14 @@ class BollingerBandsTestCase(WithTechnicalFactor, ZiplineTestCase):
|
||||
lower_cols = []
|
||||
middle_cols = []
|
||||
upper_cols = []
|
||||
for n in range(self.nassets):
|
||||
|
||||
ndates, nassets = closes.shape
|
||||
|
||||
for n in range(nassets):
|
||||
close_col = closes[:, n]
|
||||
if np.isnan(close_col).all():
|
||||
# ta-lib doesn't deal well with all nans.
|
||||
upper, middle, lower = [np.full(self.ndays, np.nan)] * 3
|
||||
upper, middle, lower = [np.full(ndates, np.nan)] * 3
|
||||
else:
|
||||
upper, middle, lower = talib.BBANDS(
|
||||
close_col,
|
||||
@@ -112,38 +73,38 @@ class BollingerBandsTestCase(WithTechnicalFactor, ZiplineTestCase):
|
||||
@parameter_space(
|
||||
window_length={5, 10, 20},
|
||||
k={1.5, 2, 2.5},
|
||||
mask_sid={True, False},
|
||||
mask_last_sid={True, False},
|
||||
__fail_fast=True,
|
||||
)
|
||||
def test_bollinger_bands(self, window_length, k, mask_sid):
|
||||
closes = self.closes(mask_sid)
|
||||
result = self.run_graph(
|
||||
TermGraph({
|
||||
'f': BollingerBands(
|
||||
window_length=window_length,
|
||||
k=k,
|
||||
),
|
||||
}),
|
||||
def test_bollinger_bands(self, window_length, k, mask_last_sid):
|
||||
closes = self.closes(mask_last_sid=mask_last_sid)
|
||||
mask = ~np.isnan(closes)
|
||||
bbands = BollingerBands(window_length=window_length, k=k)
|
||||
|
||||
expected = self.expected_bbands(window_length, k, closes)
|
||||
|
||||
self.check_terms(
|
||||
terms={
|
||||
'upper': bbands.upper,
|
||||
'middle': bbands.middle,
|
||||
'lower': bbands.lower,
|
||||
},
|
||||
expected={
|
||||
'upper': expected[0],
|
||||
'middle': expected[1],
|
||||
'lower': expected[2],
|
||||
},
|
||||
initial_workspace={
|
||||
USEquityPricing.close: AdjustedArray(
|
||||
closes,
|
||||
np.full_like(closes, True, dtype=bool),
|
||||
{},
|
||||
np.nan,
|
||||
data=closes,
|
||||
mask=mask,
|
||||
adjustments={},
|
||||
missing_value=np.nan,
|
||||
),
|
||||
},
|
||||
mask_sid=mask_sid,
|
||||
)['f']
|
||||
|
||||
expected_upper, expected_middle, expected_lower = self.expected(
|
||||
window_length,
|
||||
k,
|
||||
closes,
|
||||
mask=self.build_mask(mask),
|
||||
)
|
||||
|
||||
assert_equal(result.upper, expected_upper)
|
||||
assert_equal(result.middle, expected_middle)
|
||||
assert_equal(result.lower, expected_lower)
|
||||
|
||||
def test_bollinger_bands_output_ordering(self):
|
||||
bbands = BollingerBands(window_length=5, k=2)
|
||||
lower, middle, upper = bbands
|
||||
@@ -185,7 +146,7 @@ class AroonTestCase(ZiplineTestCase):
|
||||
assert_equal(out, expected_out)
|
||||
|
||||
|
||||
class TestFastStochasticOscillator(WithTechnicalFactor, ZiplineTestCase):
|
||||
class TestFastStochasticOscillator(ZiplineTestCase):
|
||||
"""
|
||||
Test the Fast Stochastic Oscillator
|
||||
"""
|
||||
@@ -427,7 +388,7 @@ class TestLinearWeightedMovingAverage(ZiplineTestCase):
|
||||
assert_equal(out, np.array([30., 31., 32., 33., 34.]))
|
||||
|
||||
|
||||
class TestTrueRange(WithTechnicalFactor, ZiplineTestCase):
|
||||
class TestTrueRange(ZiplineTestCase):
|
||||
|
||||
def test_tr_basic(self):
|
||||
tr = TrueRange()
|
||||
|
||||
@@ -6,6 +6,7 @@ from itertools import product
|
||||
from unittest import TestCase
|
||||
|
||||
from toolz import assoc
|
||||
import pandas as pd
|
||||
|
||||
from zipline.assets import Asset
|
||||
from zipline.errors import (
|
||||
@@ -24,7 +25,7 @@ from zipline.pipeline import (
|
||||
CustomFactor,
|
||||
Factor,
|
||||
Filter,
|
||||
TermGraph,
|
||||
ExecutionPlan,
|
||||
)
|
||||
from zipline.pipeline.data import Column, DataSet
|
||||
from zipline.pipeline.data.testing import TestingDataSet
|
||||
@@ -33,6 +34,7 @@ from zipline.pipeline.factors import RecarrayField
|
||||
from zipline.pipeline.sentinels import NotSpecified
|
||||
from zipline.pipeline.term import AssetExists, Slice
|
||||
from zipline.testing import parameter_space
|
||||
from zipline.testing.fixtures import WithTradingSessions, ZiplineTestCase
|
||||
from zipline.testing.predicates import (
|
||||
assert_equal,
|
||||
assert_raises,
|
||||
@@ -152,7 +154,14 @@ def to_dict(l):
|
||||
return dict(zip(map(str, range(len(l))), l))
|
||||
|
||||
|
||||
class DependencyResolutionTestCase(TestCase):
|
||||
class DependencyResolutionTestCase(WithTradingSessions, ZiplineTestCase):
|
||||
|
||||
TRADING_CALENDAR_STRS = ('NYSE',)
|
||||
START_DATE = pd.Timestamp('2014-01-02', tz='UTC')
|
||||
END_DATE = pd.Timestamp('2014-12-31', tz='UTC')
|
||||
|
||||
execution_plan_start = pd.Timestamp('2014-06-01', tz='UTC')
|
||||
execution_plan_end = pd.Timestamp('2014-06-30', tz='UTC')
|
||||
|
||||
def check_dependency_order(self, ordered_terms):
|
||||
seen = set()
|
||||
@@ -163,6 +172,14 @@ class DependencyResolutionTestCase(TestCase):
|
||||
|
||||
seen.add(term)
|
||||
|
||||
def make_execution_plan(self, terms):
|
||||
return ExecutionPlan(
|
||||
terms,
|
||||
self.nyse_sessions,
|
||||
self.execution_plan_start,
|
||||
self.execution_plan_end,
|
||||
)
|
||||
|
||||
def test_single_factor(self):
|
||||
"""
|
||||
Test dependency resolution for a single factor.
|
||||
@@ -182,7 +199,7 @@ class DependencyResolutionTestCase(TestCase):
|
||||
self.assertEqual(graph.node[SomeDataSet.bar]['extra_rows'], 4)
|
||||
|
||||
for foobar in gen_equivalent_factors():
|
||||
check_output(TermGraph(to_dict([foobar])))
|
||||
check_output(self.make_execution_plan(to_dict([foobar])))
|
||||
|
||||
def test_single_factor_instance_args(self):
|
||||
"""
|
||||
@@ -190,7 +207,9 @@ class DependencyResolutionTestCase(TestCase):
|
||||
the constructor.
|
||||
"""
|
||||
bar, buzz = SomeDataSet.bar, SomeDataSet.buzz
|
||||
graph = TermGraph(to_dict([SomeFactor([bar, buzz], window_length=5)]))
|
||||
|
||||
factor = SomeFactor([bar, buzz], window_length=5)
|
||||
graph = self.make_execution_plan(to_dict([factor]))
|
||||
|
||||
resolution_order = list(graph.ordered())
|
||||
|
||||
@@ -214,7 +233,7 @@ class DependencyResolutionTestCase(TestCase):
|
||||
f1 = SomeFactor([SomeDataSet.foo, SomeDataSet.bar])
|
||||
f2 = SomeOtherFactor([SomeDataSet.bar, SomeDataSet.buzz])
|
||||
|
||||
graph = TermGraph(to_dict([f1, f2]))
|
||||
graph = self.make_execution_plan(to_dict([f1, f2]))
|
||||
resolution_order = list(graph.ordered())
|
||||
|
||||
# bar should only appear once.
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
Tests for zipline/utils/pandas_utils.py
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
from zipline.testing import parameter_space, ZiplineTestCase
|
||||
from zipline.utils.pandas_utils import nearest_unequal_elements
|
||||
|
||||
|
||||
class TestNearestUnequalElements(ZiplineTestCase):
|
||||
|
||||
@parameter_space(tz=['UTC', 'US/Eastern'], __fail_fast=True)
|
||||
def test_nearest_unequal_elements(self, tz):
|
||||
|
||||
dts = pd.to_datetime(
|
||||
['2014-01-01', '2014-01-05', '2014-01-06', '2014-01-09'],
|
||||
).tz_localize(tz)
|
||||
|
||||
t = lambda s: None if s is None else pd.Timestamp(s, tz=tz)
|
||||
|
||||
for dt, before, after in (('2013-12-30', None, '2014-01-01'),
|
||||
('2013-12-31', None, '2014-01-01'),
|
||||
('2014-01-01', None, '2014-01-05'),
|
||||
('2014-01-02', '2014-01-01', '2014-01-05'),
|
||||
('2014-01-03', '2014-01-01', '2014-01-05'),
|
||||
('2014-01-04', '2014-01-01', '2014-01-05'),
|
||||
('2014-01-05', '2014-01-01', '2014-01-06'),
|
||||
('2014-01-06', '2014-01-05', '2014-01-09'),
|
||||
('2014-01-07', '2014-01-06', '2014-01-09'),
|
||||
('2014-01-08', '2014-01-06', '2014-01-09'),
|
||||
('2014-01-09', '2014-01-06', None),
|
||||
('2014-01-10', '2014-01-09', None),
|
||||
('2014-01-11', '2014-01-09', None)):
|
||||
computed = nearest_unequal_elements(dts, t(dt))
|
||||
expected = (t(before), t(after))
|
||||
self.assertEqual(computed, expected)
|
||||
|
||||
@parameter_space(tz=['UTC', 'US/Eastern'], __fail_fast=True)
|
||||
def test_nearest_unequal_elements_short_dts(self, tz):
|
||||
|
||||
# Length 1.
|
||||
dts = pd.to_datetime(['2014-01-01']).tz_localize(tz)
|
||||
t = lambda s: None if s is None else pd.Timestamp(s, tz=tz)
|
||||
|
||||
for dt, before, after in (('2013-12-31', None, '2014-01-01'),
|
||||
('2014-01-01', None, None),
|
||||
('2014-01-02', '2014-01-01', None)):
|
||||
computed = nearest_unequal_elements(dts, t(dt))
|
||||
expected = (t(before), t(after))
|
||||
self.assertEqual(computed, expected)
|
||||
|
||||
# Length 0
|
||||
dts = pd.to_datetime([]).tz_localize(tz)
|
||||
for dt, before, after in (('2013-12-31', None, None),
|
||||
('2014-01-01', None, None),
|
||||
('2014-01-02', None, None)):
|
||||
computed = nearest_unequal_elements(dts, t(dt))
|
||||
expected = (t(before), t(after))
|
||||
self.assertEqual(computed, expected)
|
||||
|
||||
def test_nearest_unequal_bad_input(self):
|
||||
with self.assertRaises(ValueError) as e:
|
||||
nearest_unequal_elements(
|
||||
pd.to_datetime(['2014', '2014']),
|
||||
pd.Timestamp('2014'),
|
||||
)
|
||||
|
||||
self.assertEqual(str(e.exception), 'dts must be unique')
|
||||
|
||||
with self.assertRaises(ValueError) as e:
|
||||
nearest_unequal_elements(
|
||||
pd.to_datetime(['2014', '2013']),
|
||||
pd.Timestamp('2014'),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
str(e.exception),
|
||||
'dts must be sorted in increasing order',
|
||||
)
|
||||
@@ -262,7 +262,11 @@ class PreprocessTestCase(TestCase):
|
||||
expected_message = (
|
||||
"{qualname}() expected a value in {set_!r}"
|
||||
" for argument 'a', but got 'c' instead."
|
||||
).format(set_=set_, qualname=qualname(f))
|
||||
).format(
|
||||
# We special-case set to show a tuple instead of the set repr.
|
||||
set_=tuple(sorted(set_)),
|
||||
qualname=qualname(f),
|
||||
)
|
||||
self.assertEqual(e.exception.args[0], expected_message)
|
||||
|
||||
def test_expect_dtypes(self):
|
||||
|
||||
@@ -59,6 +59,7 @@ from .asset_db_schema import (
|
||||
)
|
||||
from zipline.utils.control_flow import invert
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.numpy_utils import as_column
|
||||
from zipline.utils.sqlite_utils import group_into_chunks
|
||||
|
||||
log = Logger('assets.py')
|
||||
@@ -1096,7 +1097,7 @@ class AssetFinder(object):
|
||||
self._asset_lifetimes = self._compute_asset_lifetimes()
|
||||
lifetimes = self._asset_lifetimes
|
||||
|
||||
raw_dates = dates.asi8[:, None]
|
||||
raw_dates = as_column(dates.asi8)
|
||||
if include_start_date:
|
||||
mask = lifetimes.start <= raw_dates
|
||||
else:
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from textwrap import dedent
|
||||
|
||||
from zipline.utils.memoize import lazyval
|
||||
|
||||
@@ -583,6 +584,29 @@ class NoFurtherDataError(ZiplineError):
|
||||
# that can be usefully templated.
|
||||
msg = '{msg}'
|
||||
|
||||
@classmethod
|
||||
def from_lookback_window(cls,
|
||||
initial_message,
|
||||
first_date,
|
||||
lookback_start,
|
||||
lookback_length):
|
||||
return cls(
|
||||
msg=dedent(
|
||||
"""
|
||||
{initial_message}
|
||||
|
||||
lookback window started at {lookback_start}
|
||||
earliest known date was {first_date}
|
||||
{lookback_length} extra rows of data were required
|
||||
"""
|
||||
).format(
|
||||
initial_message=initial_message,
|
||||
first_date=first_date,
|
||||
lookback_start=lookback_start,
|
||||
lookback_length=lookback_length,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedDatetimeFormat(ZiplineError):
|
||||
"""
|
||||
|
||||
@@ -6,7 +6,7 @@ from .engine import SimplePipelineEngine
|
||||
from .factors import Factor, CustomFactor
|
||||
from .filters import Filter, CustomFilter
|
||||
from .term import Term
|
||||
from .graph import TermGraph
|
||||
from .graph import ExecutionPlan, TermGraph
|
||||
from .pipeline import Pipeline
|
||||
from .loaders import USEquityPricingLoader
|
||||
|
||||
@@ -56,6 +56,7 @@ __all__ = (
|
||||
'CustomFilter',
|
||||
'CustomClassifier',
|
||||
'engine_from_files',
|
||||
'ExecutionPlan',
|
||||
'Factor',
|
||||
'Filter',
|
||||
'Pipeline',
|
||||
|
||||
@@ -14,6 +14,7 @@ from zipline.pipeline.sentinels import NotSpecified
|
||||
from zipline.pipeline.term import ComputableTerm
|
||||
from zipline.utils.compat import unicode
|
||||
from zipline.utils.input_validation import expect_types
|
||||
from zipline.utils.memoize import classlazyval
|
||||
from zipline.utils.numpy_utils import (
|
||||
categorical_dtype,
|
||||
int64_dtype,
|
||||
@@ -23,6 +24,7 @@ from zipline.utils.numpy_utils import (
|
||||
from ..filters import ArrayPredicate, NotNullFilter, NullFilter, NumExprFilter
|
||||
from ..mixins import (
|
||||
CustomTermMixin,
|
||||
DownsampledMixin,
|
||||
LatestMixin,
|
||||
PositiveWindowLengthMixin,
|
||||
RestrictedDTypeMixin,
|
||||
@@ -301,6 +303,10 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
|
||||
raise AssertionError("Expected a LabelArray, got %s." % type(data))
|
||||
return data.as_categorical()
|
||||
|
||||
@classlazyval
|
||||
def _downsampled_type(self):
|
||||
return DownsampledMixin.make_downsampled_type(Classifier)
|
||||
|
||||
|
||||
class Everything(Classifier):
|
||||
"""
|
||||
|
||||
@@ -112,7 +112,6 @@ class BoundColumn(LoadableTerm):
|
||||
The name of this column.
|
||||
"""
|
||||
mask = AssetExists()
|
||||
inputs = ()
|
||||
window_safe = True
|
||||
|
||||
def __new__(cls, dtype, missing_value, dataset, name):
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Helpers for downsampling code.
|
||||
"""
|
||||
from operator import attrgetter
|
||||
|
||||
from zipline.utils.input_validation import expect_element
|
||||
from zipline.utils.numpy_utils import changed_locations
|
||||
from zipline.utils.sharedoc import (
|
||||
templated_docstring,
|
||||
PIPELINE_DOWNSAMPLING_FREQUENCY_DOC,
|
||||
)
|
||||
|
||||
_dt_to_period = {
|
||||
'year_start': attrgetter('year'),
|
||||
'quarter_start': attrgetter('quarter'),
|
||||
'month_start': attrgetter('month'),
|
||||
'week_start': attrgetter('week'),
|
||||
}
|
||||
|
||||
SUPPORTED_DOWNSAMPLE_FREQUENCIES = frozenset(_dt_to_period)
|
||||
|
||||
|
||||
expect_downsample_frequency = expect_element(
|
||||
frequency=SUPPORTED_DOWNSAMPLE_FREQUENCIES,
|
||||
)
|
||||
|
||||
|
||||
@expect_downsample_frequency
|
||||
@templated_docstring(frequency=PIPELINE_DOWNSAMPLING_FREQUENCY_DOC)
|
||||
def select_sampling_indices(dates, frequency):
|
||||
"""
|
||||
Choose entries from ``dates`` to use for downsampling at ``frequency``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dates : pd.DatetimeIndex
|
||||
Dates from which to select sample choices.
|
||||
{frequency}
|
||||
|
||||
Returns
|
||||
-------
|
||||
indices : np.array[int64]
|
||||
An array containing indices of dates on which samples should be taken.
|
||||
|
||||
The resulting index will always include 0 as a sample index, and it
|
||||
will include the first date of each subsequent year/quarter/month/week,
|
||||
as determined by ``frequency``.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function assumes that ``dates`` does not have large gaps.
|
||||
|
||||
In particular, it assumes that the maximum distance between any two entries
|
||||
in ``dates`` is never greater than a year, which we rely on because we use
|
||||
``np.diff(dates.<frequency>)`` to find dates where the sampling
|
||||
period has changed.
|
||||
"""
|
||||
return changed_locations(
|
||||
_dt_to_period[frequency](dates),
|
||||
include_first=True
|
||||
)
|
||||
+25
-11
@@ -18,10 +18,14 @@ from toolz.curried.operator import getitem
|
||||
|
||||
from zipline.lib.adjusted_array import ensure_adjusted_array, ensure_ndarray
|
||||
from zipline.errors import NoFurtherDataError
|
||||
from zipline.utils.numpy_utils import repeat_first_axis, repeat_last_axis
|
||||
from zipline.utils.numpy_utils import (
|
||||
as_column,
|
||||
repeat_first_axis,
|
||||
repeat_last_axis,
|
||||
)
|
||||
from zipline.utils.pandas_utils import explode
|
||||
|
||||
from .term import AssetExists, LoadableTerm
|
||||
from .term import AssetExists, InputDates, LoadableTerm
|
||||
|
||||
|
||||
class PipelineEngine(with_metaclass(ABCMeta)):
|
||||
@@ -98,6 +102,7 @@ class SimplePipelineEngine(object):
|
||||
'_calendar',
|
||||
'_finder',
|
||||
'_root_mask_term',
|
||||
'_root_mask_dates_term',
|
||||
'__weakref__',
|
||||
)
|
||||
|
||||
@@ -105,7 +110,9 @@ class SimplePipelineEngine(object):
|
||||
self._get_loader = get_loader
|
||||
self._calendar = calendar
|
||||
self._finder = asset_finder
|
||||
|
||||
self._root_mask_term = AssetExists()
|
||||
self._root_mask_dates_term = InputDates()
|
||||
|
||||
def run_pipeline(self, pipeline, start_date, end_date):
|
||||
"""
|
||||
@@ -161,7 +168,13 @@ class SimplePipelineEngine(object):
|
||||
)
|
||||
|
||||
screen_name = uuid4().hex
|
||||
graph = pipeline.to_graph(screen_name, self._root_mask_term)
|
||||
graph = pipeline.to_execution_plan(
|
||||
screen_name,
|
||||
self._root_mask_term,
|
||||
self._calendar,
|
||||
start_date,
|
||||
end_date,
|
||||
)
|
||||
extra_rows = graph.extra_rows[self._root_mask_term]
|
||||
root_mask = self._compute_root_mask(start_date, end_date, extra_rows)
|
||||
dates, assets, root_mask_values = explode(root_mask)
|
||||
@@ -170,7 +183,10 @@ class SimplePipelineEngine(object):
|
||||
graph,
|
||||
dates,
|
||||
assets,
|
||||
initial_workspace={self._root_mask_term: root_mask_values},
|
||||
initial_workspace={
|
||||
self._root_mask_term: root_mask_values,
|
||||
self._root_mask_dates_term: as_column(dates.values)
|
||||
},
|
||||
)
|
||||
|
||||
return self._to_narrow(
|
||||
@@ -210,13 +226,11 @@ class SimplePipelineEngine(object):
|
||||
finder = self._finder
|
||||
start_idx, end_idx = self._calendar.slice_locs(start_date, end_date)
|
||||
if start_idx < extra_rows:
|
||||
raise NoFurtherDataError(
|
||||
msg="Insufficient data to compute Pipeline mask: "
|
||||
"start date was %s, "
|
||||
"earliest known date was %s, "
|
||||
"and %d extra rows were requested." % (
|
||||
start_date, calendar[0], extra_rows,
|
||||
),
|
||||
raise NoFurtherDataError.from_lookback_window(
|
||||
initial_message="Insufficient data to compute Pipeline:",
|
||||
first_date=calendar[0],
|
||||
lookback_start=start_date,
|
||||
lookback_length=extra_rows,
|
||||
)
|
||||
|
||||
# Build lifetimes matrix reaching back to `extra_rows` days before
|
||||
|
||||
@@ -33,6 +33,7 @@ from zipline.pipeline.filters import (
|
||||
)
|
||||
from zipline.pipeline.mixins import (
|
||||
CustomTermMixin,
|
||||
DownsampledMixin,
|
||||
LatestMixin,
|
||||
PositiveWindowLengthMixin,
|
||||
RestrictedDTypeMixin,
|
||||
@@ -43,6 +44,7 @@ from zipline.pipeline.term import ComputableTerm, Term
|
||||
from zipline.utils.functional import with_doc, with_name
|
||||
from zipline.utils.input_validation import expect_types
|
||||
from zipline.utils.math_utils import nanmean, nanstd
|
||||
from zipline.utils.memoize import classlazyval
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
categorical_dtype,
|
||||
@@ -1071,6 +1073,10 @@ class Factor(RestrictedDTypeMixin, ComputableTerm):
|
||||
"""
|
||||
return (-inf < self) & (self < inf)
|
||||
|
||||
@classlazyval
|
||||
def _downsampled_type(self):
|
||||
return DownsampledMixin.make_downsampled_type(Factor)
|
||||
|
||||
|
||||
class NumExprFactor(NumericalExpression, Factor):
|
||||
"""
|
||||
@@ -1468,7 +1474,9 @@ class CustomFactor(PositiveWindowLengthMixin, CustomTermMixin, Factor):
|
||||
|
||||
|
||||
class RecarrayField(SingleInputMixin, Factor):
|
||||
|
||||
"""
|
||||
A single field from a multi-output factor.
|
||||
"""
|
||||
def __new__(cls, factor, attribute):
|
||||
return super(RecarrayField, cls).__new__(
|
||||
cls,
|
||||
|
||||
@@ -25,6 +25,7 @@ from zipline.pipeline.expression import (
|
||||
)
|
||||
from zipline.pipeline.mixins import (
|
||||
CustomTermMixin,
|
||||
DownsampledMixin,
|
||||
LatestMixin,
|
||||
PositiveWindowLengthMixin,
|
||||
RestrictedDTypeMixin,
|
||||
@@ -32,6 +33,7 @@ from zipline.pipeline.mixins import (
|
||||
)
|
||||
from zipline.pipeline.term import ComputableTerm, Term
|
||||
from zipline.utils.input_validation import expect_types
|
||||
from zipline.utils.memoize import classlazyval
|
||||
from zipline.utils.numpy_utils import bool_dtype, repeat_first_axis
|
||||
|
||||
|
||||
@@ -201,6 +203,10 @@ class Filter(RestrictedDTypeMixin, ComputableTerm):
|
||||
)
|
||||
return retval
|
||||
|
||||
@classlazyval
|
||||
def _downsampled_type(self):
|
||||
return DownsampledMixin.make_downsampled_type(Filter)
|
||||
|
||||
|
||||
class NumExprFilter(NumericalExpression, Filter):
|
||||
"""
|
||||
|
||||
+163
-79
@@ -5,7 +5,7 @@ from networkx import (
|
||||
DiGraph,
|
||||
topological_sort,
|
||||
)
|
||||
from six import itervalues, iteritems
|
||||
from six import iteritems, itervalues
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.pipeline.visualize import display_graph
|
||||
|
||||
@@ -18,7 +18,112 @@ class CyclicDependency(Exception):
|
||||
|
||||
class TermGraph(DiGraph):
|
||||
"""
|
||||
Graph represention of Pipeline Term dependencies.
|
||||
An abstract representation of Pipeline Term dependencies.
|
||||
|
||||
This class does not keep any additional metadata about any term relations
|
||||
other than dependency ordering. As such it is only useful in contexts
|
||||
where you care exclusively about order properties (for example, when
|
||||
drawing visualizations of execution order).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
terms : dict
|
||||
A dict mapping names to final output terms.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
outputs
|
||||
|
||||
Methods
|
||||
-------
|
||||
ordered()
|
||||
Return a topologically-sorted iterator over the terms in self.
|
||||
|
||||
See Also
|
||||
--------
|
||||
ExecutionPlan
|
||||
"""
|
||||
def __init__(self, terms):
|
||||
super(TermGraph, self).__init__()
|
||||
|
||||
self._frozen = False
|
||||
parents = set()
|
||||
for term in itervalues(terms):
|
||||
self._add_to_graph(term, parents)
|
||||
# No parents should be left between top-level terms.
|
||||
assert not parents
|
||||
|
||||
self._outputs = terms
|
||||
self._ordered = topological_sort(self)
|
||||
|
||||
# Mark that no more terms should be added to the graph.
|
||||
self._frozen = True
|
||||
|
||||
def _add_to_graph(self, term, parents):
|
||||
"""
|
||||
Add a term and all its children to ``graph``.
|
||||
|
||||
``parents`` is the set of all the parents of ``term`` that we've added
|
||||
so far. It is only used to detect dependency cycles.
|
||||
"""
|
||||
if self._frozen:
|
||||
raise ValueError(
|
||||
"Can't mutate %s after construction." % type(self).__name__
|
||||
)
|
||||
|
||||
# If we've seen this node already as a parent of the current traversal,
|
||||
# it means we have an unsatisfiable dependency. This should only be
|
||||
# possible if the term's inputs are mutated after construction.
|
||||
if term in parents:
|
||||
raise CyclicDependency(term)
|
||||
|
||||
parents.add(term)
|
||||
|
||||
self.add_node(term)
|
||||
|
||||
for dependency in term.dependencies:
|
||||
self._add_to_graph(dependency, parents)
|
||||
self.add_edge(dependency, term)
|
||||
|
||||
parents.remove(term)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
"""
|
||||
Dict mapping names to designated output terms.
|
||||
"""
|
||||
return self._outputs
|
||||
|
||||
def ordered(self):
|
||||
"""
|
||||
Return a topologically-sorted iterator over the terms in `self`.
|
||||
"""
|
||||
return iter(self._ordered)
|
||||
|
||||
@lazyval
|
||||
def loadable_terms(self):
|
||||
return tuple(term for term in self if isinstance(term, LoadableTerm))
|
||||
|
||||
@lazyval
|
||||
def jpeg(self):
|
||||
return display_graph(self, 'jpeg')
|
||||
|
||||
@lazyval
|
||||
def png(self):
|
||||
return display_graph(self, 'png')
|
||||
|
||||
@lazyval
|
||||
def svg(self):
|
||||
return display_graph(self, 'svg')
|
||||
|
||||
def _repr_png_(self):
|
||||
return self.png.data
|
||||
|
||||
|
||||
class ExecutionPlan(TermGraph):
|
||||
"""
|
||||
Graph representation of Pipeline Term dependencies that includes metadata
|
||||
about extra rows required to perform computations.
|
||||
|
||||
Each node in the graph has an `extra_rows` attribute, indicating how many,
|
||||
if any, extra rows we should compute for the node. Extra rows are most
|
||||
@@ -30,6 +135,13 @@ class TermGraph(DiGraph):
|
||||
----------
|
||||
terms : dict
|
||||
A dict mapping names to final output terms.
|
||||
all_dates : pd.DatetimeIndex
|
||||
An index of all known trading days for which ``terms`` will be
|
||||
computed.
|
||||
start_date : pd.Timestamp
|
||||
The first date for which output is requested for ``terms``.
|
||||
end_date : pd.Timestamp
|
||||
The last date for which output is requested for ``terms``.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -42,21 +154,58 @@ class TermGraph(DiGraph):
|
||||
ordered()
|
||||
Return a topologically-sorted iterator over the terms in self.
|
||||
"""
|
||||
def __init__(self, terms):
|
||||
super(TermGraph, self).__init__(self)
|
||||
def __init__(self,
|
||||
terms,
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date,
|
||||
min_extra_rows=0):
|
||||
super(ExecutionPlan, self).__init__(terms)
|
||||
|
||||
self._frozen = False
|
||||
parents = set()
|
||||
for term in itervalues(terms):
|
||||
self._add_to_graph(term, parents, extra_rows=0)
|
||||
# No parents should be left between top-level terms.
|
||||
assert not parents
|
||||
for term in terms.values():
|
||||
self.set_extra_rows(
|
||||
term,
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date,
|
||||
min_extra_rows=min_extra_rows,
|
||||
)
|
||||
|
||||
self._outputs = terms
|
||||
self._ordered = topological_sort(self)
|
||||
def set_extra_rows(self,
|
||||
term,
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date,
|
||||
min_extra_rows):
|
||||
"""
|
||||
Compute ``extra_rows`` for transitive dependencies of ``root_terms``
|
||||
"""
|
||||
# A term can require that additional extra rows beyond the minimum be
|
||||
# computed. This is most often used with downsampled terms, which need
|
||||
# to ensure that the first date is a computation date.
|
||||
extra_rows_for_term = term.compute_extra_rows(
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date,
|
||||
min_extra_rows,
|
||||
)
|
||||
if extra_rows_for_term < min_extra_rows:
|
||||
raise ValueError(
|
||||
"term %s requested fewer rows than the minimum of %d" % (
|
||||
term, min_extra_rows,
|
||||
)
|
||||
)
|
||||
|
||||
# Mark that no more terms should be added to the graph.
|
||||
self._frozen = True
|
||||
self._ensure_extra_rows(term, extra_rows_for_term)
|
||||
|
||||
for dependency, additional_extra_rows in term.dependencies.items():
|
||||
self.set_extra_rows(
|
||||
dependency,
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date,
|
||||
min_extra_rows=extra_rows_for_term + additional_extra_rows,
|
||||
)
|
||||
|
||||
@lazyval
|
||||
def offset(self):
|
||||
@@ -138,9 +287,6 @@ class TermGraph(DiGraph):
|
||||
"""
|
||||
A dict mapping `term` -> `# of extra rows to load/compute of `term`.
|
||||
|
||||
This is always the maximum number of extra **input** rows required by
|
||||
any Filter/Factor for which `term` is an input.
|
||||
|
||||
Notes
|
||||
----
|
||||
This value depends on the other terms in the graph that require `term`
|
||||
@@ -175,71 +321,9 @@ class TermGraph(DiGraph):
|
||||
for term, attrs in iteritems(self.node)
|
||||
}
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
"""
|
||||
Dict mapping names to designated output terms.
|
||||
"""
|
||||
return self._outputs
|
||||
|
||||
def ordered(self):
|
||||
"""
|
||||
Return a topologically-sorted iterator over the terms in `self`.
|
||||
"""
|
||||
return iter(self._ordered)
|
||||
|
||||
@lazyval
|
||||
def loadable_terms(self):
|
||||
return tuple(term for term in self if isinstance(term, LoadableTerm))
|
||||
|
||||
def _add_to_graph(self, term, parents, extra_rows):
|
||||
"""
|
||||
Add `term` and all its inputs to the graph.
|
||||
"""
|
||||
if self._frozen:
|
||||
raise ValueError("Can't mutate `TermGraph` after construction.")
|
||||
# If we've seen this node already as a parent of the current traversal,
|
||||
# it means we have an unsatisifiable dependency. This should only be
|
||||
# possible if the term's inputs are mutated after construction.
|
||||
if term in parents:
|
||||
raise CyclicDependency(term)
|
||||
parents.add(term)
|
||||
|
||||
# Idempotent if term is already in the graph.
|
||||
self.add_node(term)
|
||||
|
||||
# Make sure we're going to compute at least `extra_rows` of `term`.
|
||||
self._ensure_extra_rows(term, extra_rows)
|
||||
|
||||
# Recursively add dependencies.
|
||||
for dependency, additional_extra_rows in term.dependencies.items():
|
||||
self._add_to_graph(
|
||||
dependency,
|
||||
parents,
|
||||
extra_rows=extra_rows + additional_extra_rows,
|
||||
)
|
||||
self.add_edge(dependency, term)
|
||||
|
||||
parents.remove(term)
|
||||
|
||||
def _ensure_extra_rows(self, term, N):
|
||||
"""
|
||||
Ensure that we're going to compute at least N extra rows of `term`.
|
||||
"""
|
||||
attrs = self.node[term]
|
||||
attrs['extra_rows'] = max(N, attrs.get('extra_rows', 0))
|
||||
|
||||
@lazyval
|
||||
def jpeg(self):
|
||||
return display_graph(self, 'jpeg')
|
||||
|
||||
@lazyval
|
||||
def png(self):
|
||||
return display_graph(self, 'png')
|
||||
|
||||
@lazyval
|
||||
def svg(self):
|
||||
return display_graph(self, 'svg')
|
||||
|
||||
def _repr_png_(self):
|
||||
return self.png.data
|
||||
|
||||
@@ -15,6 +15,7 @@ from pandas import (
|
||||
)
|
||||
from zipline.lib.adjusted_array import AdjustedArray
|
||||
from zipline.lib.adjustment import make_adjustment_from_labels
|
||||
from zipline.utils.numpy_utils import as_column
|
||||
from zipline.utils.pandas_utils import sort_values
|
||||
from .base import PipelineLoader
|
||||
|
||||
@@ -169,7 +170,7 @@ class DataFrameLoader(PipelineLoader):
|
||||
# Pull out requested columns/rows from our baseline data.
|
||||
data=self.baseline[ix_(date_indexer, assets_indexer)],
|
||||
# Mask out requested columns/rows that didnt match.
|
||||
mask=(good_assets & good_dates[:, None]) & mask,
|
||||
mask=(good_assets & as_column(good_dates)) & mask,
|
||||
adjustments=self.format_adjustments(dates, assets),
|
||||
missing_value=column.missing_value,
|
||||
),
|
||||
|
||||
+243
-1
@@ -1,16 +1,36 @@
|
||||
"""
|
||||
Mixin classes for use with Filters and Factors.
|
||||
"""
|
||||
from textwrap import dedent
|
||||
|
||||
from numpy import (
|
||||
array,
|
||||
full,
|
||||
recarray,
|
||||
vstack,
|
||||
)
|
||||
from pandas import NaT as pd_NaT
|
||||
|
||||
from zipline.errors import (
|
||||
WindowLengthNotPositive,
|
||||
UnsupportedDataType,
|
||||
NoFurtherDataError,
|
||||
)
|
||||
from zipline.utils.control_flow import nullctx
|
||||
from zipline.errors import WindowLengthNotPositive, UnsupportedDataType
|
||||
from zipline.utils.input_validation import expect_types
|
||||
from zipline.utils.sharedoc import (
|
||||
format_docstring,
|
||||
PIPELINE_DOWNSAMPLING_FREQUENCY_DOC,
|
||||
)
|
||||
from zipline.utils.pandas_utils import nearest_unequal_elements
|
||||
|
||||
|
||||
from .downsample_helpers import (
|
||||
select_sampling_indices,
|
||||
expect_downsample_frequency,
|
||||
)
|
||||
from .sentinels import NotSpecified
|
||||
from .term import Term
|
||||
|
||||
|
||||
class PositiveWindowLengthMixin(object):
|
||||
@@ -218,3 +238,225 @@ class LatestMixin(SingleInputMixin):
|
||||
actual=self.inputs[0].dtype,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class DownsampledMixin(StandardOutputs):
|
||||
"""
|
||||
Mixin for behavior shared by Downsampled{Factor,Filter,Classifier}
|
||||
|
||||
A downsampled term is a wrapper around the "real" term that performs actual
|
||||
computation. The downsampler is responsible for calling the real term's
|
||||
`compute` method at selected intervals and forward-filling the computed
|
||||
values.
|
||||
|
||||
Downsampling is not currently supported for terms with multiple outputs.
|
||||
"""
|
||||
# There's no reason to take a window of a downsampled term. The whole
|
||||
# point is that you're re-using the same result multiple times.
|
||||
window_safe = False
|
||||
|
||||
@expect_types(term=Term)
|
||||
@expect_downsample_frequency
|
||||
def __new__(cls, term, frequency):
|
||||
return super(DownsampledMixin, cls).__new__(
|
||||
cls,
|
||||
inputs=term.inputs,
|
||||
outputs=term.outputs,
|
||||
window_length=term.window_length,
|
||||
mask=term.mask,
|
||||
frequency=frequency,
|
||||
wrapped_term=term,
|
||||
dtype=term.dtype,
|
||||
missing_value=term.missing_value,
|
||||
ndim=term.ndim,
|
||||
)
|
||||
|
||||
def _init(self, frequency, wrapped_term, *args, **kwargs):
|
||||
self._frequency = frequency
|
||||
self._wrapped_term = wrapped_term
|
||||
return super(DownsampledMixin, self)._init(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _static_identity(cls, frequency, wrapped_term, *args, **kwargs):
|
||||
return (
|
||||
super(DownsampledMixin, cls)._static_identity(*args, **kwargs),
|
||||
frequency,
|
||||
wrapped_term,
|
||||
)
|
||||
|
||||
def compute_extra_rows(self,
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date,
|
||||
min_extra_rows):
|
||||
"""
|
||||
Ensure that min_extra_rows pushes us back to a computation date.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
all_dates : pd.DatetimeIndex
|
||||
The trading sessions against which ``self`` will be computed.
|
||||
start_date : pd.Timestamp
|
||||
The first date for which final output is requested.
|
||||
end_date : pd.Timestamp
|
||||
The last date for which final output is requested.
|
||||
min_extra_rows : int
|
||||
The minimum number of extra rows required of ``self``, as
|
||||
determined by other terms that depend on ``self``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
extra_rows : int
|
||||
The number of extra rows to compute. This will be the minimum
|
||||
number of rows required to make our computed start_date fall on a
|
||||
recomputation date.
|
||||
"""
|
||||
try:
|
||||
current_start_pos = all_dates.get_loc(start_date) - min_extra_rows
|
||||
if current_start_pos < 0:
|
||||
raise NoFurtherDataError(
|
||||
initial_message="Insufficient data to compute Pipeline:",
|
||||
first_date=all_dates[0],
|
||||
lookback_start=start_date,
|
||||
lookback_length=min_extra_rows,
|
||||
)
|
||||
except KeyError:
|
||||
before, after = nearest_unequal_elements(all_dates, start_date)
|
||||
raise ValueError(
|
||||
"Pipeline start_date {start_date} is not in calendar.\n"
|
||||
"Latest date before start_date is {before}.\n"
|
||||
"Earliest date after start_date is {after}.".format(
|
||||
start_date=start_date,
|
||||
before=before,
|
||||
after=after,
|
||||
)
|
||||
)
|
||||
|
||||
# Our possible target dates are all the dates on or before the current
|
||||
# starting position.
|
||||
# TODO: Consider bounding this below by self.window_length
|
||||
candidates = all_dates[:current_start_pos + 1]
|
||||
|
||||
# Choose the latest date in the candidates that is the start of a new
|
||||
# period at our frequency.
|
||||
choices = select_sampling_indices(candidates, self._frequency)
|
||||
|
||||
# If we have choices, the last choice is the first date of the
|
||||
# period containing current_start_date. Choose it.
|
||||
new_start_date = candidates[choices[-1]]
|
||||
|
||||
# Add the difference between the new and old start dates to get the
|
||||
# number of rows for the new start_date.
|
||||
new_start_pos = all_dates.get_loc(new_start_date)
|
||||
assert new_start_pos <= current_start_pos, \
|
||||
"Computed negative extra rows!"
|
||||
|
||||
return min_extra_rows + (current_start_pos - new_start_pos)
|
||||
|
||||
def _compute(self, inputs, dates, assets, mask):
|
||||
"""
|
||||
Compute by delegating to self._wrapped_term._compute on sample dates.
|
||||
|
||||
On non-sample dates, forward-fill from previously-computed samples.
|
||||
"""
|
||||
to_sample = dates[select_sampling_indices(dates, self._frequency)]
|
||||
assert to_sample[0] == dates[0], \
|
||||
"Misaligned sampling dates in %s." % type(self).__name__
|
||||
|
||||
real_compute = self._wrapped_term._compute
|
||||
|
||||
# Inputs will contain different kinds of values depending on whether or
|
||||
# not we're a windowed computation.
|
||||
|
||||
# If we're windowed, then `inputs` is a list of iterators of ndarrays.
|
||||
# If we're not windowed, then `inputs` is just a list of ndarrays.
|
||||
# There are two things we care about doing with the input:
|
||||
# 1. Preparing an input to be passed to our wrapped term.
|
||||
# 2. Skipping an input if we're going to use an already-computed row.
|
||||
# We perform these actions differently based on the expected kind of
|
||||
# input, and we encapsulate these actions with closures so that we
|
||||
# don't clutter the code below with lots of branching.
|
||||
if self.windowed:
|
||||
# If we're windowed, inputs are stateful AdjustedArrays. We don't
|
||||
# need to do any preparation before forwarding to real_compute, but
|
||||
# we need to call `next` on them if we want to skip an iteration.
|
||||
def prepare_inputs():
|
||||
return inputs
|
||||
|
||||
def skip_this_input():
|
||||
for w in inputs:
|
||||
next(w)
|
||||
else:
|
||||
# If we're not windowed, inputs are just ndarrays. We need to
|
||||
# slice out a single row when forwarding to real_compute, but we
|
||||
# don't need to do anything to skip an input.
|
||||
def prepare_inputs():
|
||||
# i is the loop iteration variable below.
|
||||
return [a[[i]] for a in inputs]
|
||||
|
||||
def skip_this_input():
|
||||
pass
|
||||
|
||||
results = []
|
||||
samples = iter(to_sample)
|
||||
next_sample = next(samples)
|
||||
for i, compute_date in enumerate(dates):
|
||||
if next_sample == compute_date:
|
||||
results.append(
|
||||
real_compute(
|
||||
prepare_inputs(),
|
||||
dates[i:i + 1],
|
||||
assets,
|
||||
mask[i:i + 1],
|
||||
)
|
||||
)
|
||||
try:
|
||||
next_sample = next(samples)
|
||||
except StopIteration:
|
||||
# No more samples to take. Set next_sample to NaT, which
|
||||
# compares False with any other datetime.
|
||||
next_sample = pd_NaT
|
||||
else:
|
||||
skip_this_input()
|
||||
# Copy results from previous sample period.
|
||||
results.append(results[-1])
|
||||
|
||||
# We should have exhausted our sample dates.
|
||||
try:
|
||||
next_sample = next(samples)
|
||||
except StopIteration:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError("Unconsumed sample date: %s" % next_sample)
|
||||
|
||||
# Concatenate stored results.
|
||||
return vstack(results)
|
||||
|
||||
@classmethod
|
||||
def make_downsampled_type(cls, other_base):
|
||||
"""
|
||||
Factory for making Downsampled{Filter,Factor,Classifier}.
|
||||
"""
|
||||
docstring = dedent(
|
||||
"""
|
||||
A {t} that defers to another {t} at lower-than-daily frequency.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
term : {t}
|
||||
{{frequency}}
|
||||
"""
|
||||
).format(t=other_base.__name__)
|
||||
|
||||
doc = format_docstring(
|
||||
owner_name=other_base.__name__,
|
||||
docstring=docstring,
|
||||
formatters={'frequency': PIPELINE_DOWNSAMPLING_FREQUENCY_DOC},
|
||||
)
|
||||
|
||||
return type(
|
||||
'Downsampled' + other_base.__name__,
|
||||
(cls, other_base,),
|
||||
{'__doc__': doc,
|
||||
'__module__': other_base.__module__},
|
||||
)
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
|
||||
from zipline.errors import UnsupportedPipelineOutput
|
||||
from zipline.utils.input_validation import expect_types, optional
|
||||
from zipline.utils.input_validation import (
|
||||
expect_element,
|
||||
expect_types,
|
||||
optional,
|
||||
)
|
||||
|
||||
from .term import AssetExists, ComputableTerm, Term
|
||||
from .graph import ExecutionPlan, TermGraph
|
||||
from .filters import Filter
|
||||
from .graph import TermGraph
|
||||
from .term import AssetExists, ComputableTerm, Term
|
||||
|
||||
|
||||
class Pipeline(object):
|
||||
@@ -148,9 +152,39 @@ class Pipeline(object):
|
||||
)
|
||||
self._screen = screen
|
||||
|
||||
def to_graph(self, screen_name, default_screen):
|
||||
def to_execution_plan(self,
|
||||
screen_name,
|
||||
default_screen,
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date):
|
||||
"""
|
||||
Compile into a TermGraph.
|
||||
Compile into an ExecutionPlan.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
screen_name : str
|
||||
Name to supply for self.screen.
|
||||
default_screen : zipline.pipeline.term.Term
|
||||
Term to use as a screen if self.screen is None.
|
||||
all_dates : pd.DatetimeIndex
|
||||
A calendar of dates to use to calculate starts and ends for each
|
||||
term.
|
||||
start_date : pd.Timestamp
|
||||
The first date of requested output.
|
||||
end_date : pd.Timestamp
|
||||
The last date of requested output.
|
||||
"""
|
||||
return ExecutionPlan(
|
||||
self._prepare_graph_terms(screen_name, default_screen),
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date,
|
||||
)
|
||||
|
||||
def to_simple_graph(self, screen_name, default_screen):
|
||||
"""
|
||||
Compile into a simple TermGraph with no extra row metadata.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -159,14 +193,20 @@ class Pipeline(object):
|
||||
default_screen : zipline.pipeline.term.Term
|
||||
Term to use as a screen if self.screen is None.
|
||||
"""
|
||||
return TermGraph(
|
||||
self._prepare_graph_terms(screen_name, default_screen)
|
||||
)
|
||||
|
||||
def _prepare_graph_terms(self, screen_name, default_screen):
|
||||
"""Helper for to_simple_graph and to_execution_plan."""
|
||||
columns = self.columns.copy()
|
||||
screen = self.screen
|
||||
if screen is None:
|
||||
screen = default_screen
|
||||
columns[screen_name] = screen
|
||||
return columns
|
||||
|
||||
return TermGraph(columns)
|
||||
|
||||
@expect_element(format=('svg', 'png', 'jpeg'))
|
||||
def show_graph(self, format='svg'):
|
||||
"""
|
||||
Render this Pipeline as a DAG.
|
||||
@@ -176,7 +216,7 @@ class Pipeline(object):
|
||||
format : {'svg', 'png', 'jpeg'}
|
||||
Image format to render with. Default is 'svg'.
|
||||
"""
|
||||
g = self.to_graph('', AssetExists())
|
||||
g = self.to_simple_graph('', AssetExists())
|
||||
if format == 'svg':
|
||||
return g.svg
|
||||
elif format == 'png':
|
||||
@@ -184,7 +224,9 @@ class Pipeline(object):
|
||||
elif format == 'jpeg':
|
||||
return g.jpeg
|
||||
else:
|
||||
raise ValueError("Unknown graph format %r." % format)
|
||||
# We should never get here because of the expect_element decorator
|
||||
# above.
|
||||
raise AssertionError("Unknown graph format %r." % format)
|
||||
|
||||
@staticmethod
|
||||
def validate_column(column_name, term):
|
||||
|
||||
+104
-3
@@ -34,10 +34,15 @@ from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
categorical_dtype,
|
||||
datetime64ns_dtype,
|
||||
default_missing_value_for_dtype,
|
||||
)
|
||||
from zipline.utils.sharedoc import (
|
||||
templated_docstring,
|
||||
PIPELINE_DOWNSAMPLING_FREQUENCY_DOC,
|
||||
)
|
||||
|
||||
from .mixins import SingleInputMixin
|
||||
from .downsample_helpers import expect_downsample_frequency
|
||||
from .sentinels import NotSpecified
|
||||
|
||||
|
||||
@@ -279,6 +284,39 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
# call super().
|
||||
self._subclass_called_super_validate = True
|
||||
|
||||
def compute_extra_rows(self,
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date,
|
||||
min_extra_rows):
|
||||
"""
|
||||
Calculate the number of extra rows needed to compute ``self``.
|
||||
|
||||
Must return at least ``min_extra_rows``, and the default implementation
|
||||
is to just return ``min_extra_rows``. This is overridden by
|
||||
downsampled terms to ensure that the first date computed is a
|
||||
recomputation date.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
all_dates : pd.DatetimeIndex
|
||||
The trading sessions against which ``self`` will be computed.
|
||||
start_date : pd.Timestamp
|
||||
The first date for which final output is requested.
|
||||
end_date : pd.Timestamp
|
||||
The last date for which final output is requested.
|
||||
min_extra_rows : int
|
||||
The minimum number of extra rows required of ``self``, as
|
||||
determined by other terms that depend on ``self``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
extra_rows : int
|
||||
The number of extra rows to compute. Must be at least
|
||||
``min_extra_rows``.
|
||||
"""
|
||||
return min_extra_rows
|
||||
|
||||
@abstractproperty
|
||||
def inputs(self):
|
||||
"""
|
||||
@@ -320,6 +358,9 @@ class AssetExists(Term):
|
||||
every asset on every date. We don't subclass Filter, however, because
|
||||
`AssetExists` is computed directly by the PipelineEngine.
|
||||
|
||||
This term is guaranteed to be available as an input for any term computed
|
||||
by SimplePipelineEngine.run_pipeline().
|
||||
|
||||
See Also
|
||||
--------
|
||||
zipline.assets.AssetFinder.lifetimes
|
||||
@@ -334,6 +375,38 @@ class AssetExists(Term):
|
||||
def __repr__(self):
|
||||
return "AssetExists()"
|
||||
|
||||
def _compute(self, today, assets, out):
|
||||
raise NotImplementedError(
|
||||
"AssetExists cannot be computed directly."
|
||||
" Check your PipelineEngine configuration."
|
||||
)
|
||||
|
||||
|
||||
class InputDates(Term):
|
||||
"""
|
||||
1-Dimensional term providing date labels for other term inputs.
|
||||
|
||||
This term is guaranteed to be available as an input for any term computed
|
||||
by SimplePipelineEngine.run_pipeline().
|
||||
"""
|
||||
ndim = 1
|
||||
dataset = None
|
||||
dtype = datetime64ns_dtype
|
||||
inputs = ()
|
||||
dependencies = {}
|
||||
mask = None
|
||||
windowed = False
|
||||
window_safe = True
|
||||
|
||||
def __repr__(self):
|
||||
return "InputDates()"
|
||||
|
||||
def _compute(self, today, assets, out):
|
||||
raise NotImplementedError(
|
||||
"InputDates cannot be computed directly."
|
||||
" Check your PipelineEngine configuration."
|
||||
)
|
||||
|
||||
|
||||
class LoadableTerm(Term):
|
||||
"""
|
||||
@@ -342,6 +415,7 @@ class LoadableTerm(Term):
|
||||
This is the base class for :class:`zipline.pipeline.data.BoundColumn`.
|
||||
"""
|
||||
windowed = False
|
||||
inputs = ()
|
||||
|
||||
@lazyval
|
||||
def dependencies(self):
|
||||
@@ -353,7 +427,7 @@ class ComputableTerm(Term):
|
||||
A Term that should be computed from a tuple of inputs.
|
||||
|
||||
This is the base class for :class:`zipline.pipeline.Factor`,
|
||||
:class:`zipline.pipeline.Filter`, and :class:`zipline.pipeline.Factor`.
|
||||
:class:`zipline.pipeline.Filter`, and :class:`zipline.pipeline.Classifier`.
|
||||
"""
|
||||
inputs = NotSpecified
|
||||
outputs = NotSpecified
|
||||
@@ -516,6 +590,27 @@ class ComputableTerm(Term):
|
||||
"""
|
||||
return data
|
||||
|
||||
def _downsampled_type(self):
|
||||
"""
|
||||
The expression type to return from self.downsample().
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"downsampling is not yet implemented "
|
||||
"for instances of %s." % type(self).__name__
|
||||
)
|
||||
|
||||
@expect_downsample_frequency
|
||||
@templated_docstring(frequency=PIPELINE_DOWNSAMPLING_FREQUENCY_DOC)
|
||||
def downsample(self, frequency):
|
||||
"""
|
||||
Make a term that computes from ``self`` at lower-than-daily frequency.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
{frequency}
|
||||
"""
|
||||
return self._downsampled_type(term=self, frequency=frequency)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{type}({inputs}, window_length={window_length})"
|
||||
@@ -526,7 +621,7 @@ class ComputableTerm(Term):
|
||||
)
|
||||
|
||||
|
||||
class Slice(ComputableTerm, SingleInputMixin):
|
||||
class Slice(ComputableTerm):
|
||||
"""
|
||||
Term for extracting a single column of another term's output.
|
||||
|
||||
@@ -582,6 +677,12 @@ class Slice(ComputableTerm, SingleInputMixin):
|
||||
# column.
|
||||
return windows[0][:, [asset_column]]
|
||||
|
||||
@property
|
||||
def _downsampled_type(self):
|
||||
raise NotImplementedError(
|
||||
'downsampling of slices is not yet supported'
|
||||
)
|
||||
|
||||
|
||||
def validate_dtype(termname, dtype, missing_value):
|
||||
"""
|
||||
|
||||
@@ -30,7 +30,6 @@ from .core import ( # noqa
|
||||
make_test_handler,
|
||||
make_trade_data_for_asset_info,
|
||||
parameter_space,
|
||||
parameter_space,
|
||||
patch_os_environment,
|
||||
patch_read_csv,
|
||||
permute_rows,
|
||||
|
||||
@@ -47,9 +47,11 @@ from zipline.pipeline.engine import SimplePipelineEngine
|
||||
from zipline.pipeline.factors import CustomFactor
|
||||
from zipline.pipeline.loaders.testing import make_seeded_random_loader
|
||||
from zipline.utils import security_list
|
||||
from zipline.utils.input_validation import expect_dimensions
|
||||
from zipline.utils.sentinel import sentinel
|
||||
from zipline.utils.calendars import get_calendar
|
||||
from zipline.utils.input_validation import expect_dimensions
|
||||
from zipline.utils.numpy_utils import as_column
|
||||
from zipline.utils.sentinel import sentinel
|
||||
|
||||
import numpy as np
|
||||
from numpy import float64
|
||||
|
||||
@@ -308,11 +310,11 @@ def make_trade_data_for_asset_info(dates,
|
||||
price_sid_deltas = np.arange(len(sids), dtype=float64) * price_step_by_sid
|
||||
price_date_deltas = (np.arange(len(dates), dtype=float64) *
|
||||
price_step_by_date)
|
||||
prices = (price_sid_deltas + price_date_deltas[:, None]) + price_start
|
||||
prices = (price_sid_deltas + as_column(price_date_deltas)) + price_start
|
||||
|
||||
volume_sid_deltas = np.arange(len(sids)) * volume_step_by_sid
|
||||
volume_date_deltas = np.arange(len(dates)) * volume_step_by_date
|
||||
volumes = (volume_sid_deltas + volume_date_deltas[:, None]) + volume_start
|
||||
volumes = volume_sid_deltas + as_column(volume_date_deltas) + volume_start
|
||||
|
||||
for j, sid in enumerate(sids):
|
||||
start_date, end_date = asset_info.loc[sid, ['start_date', 'end_date']]
|
||||
|
||||
@@ -235,6 +235,9 @@ class WithDefaultDateBounds(object):
|
||||
ZiplineTestCase mixin which makes it possible to synchronize date bounds
|
||||
across fixtures.
|
||||
|
||||
This fixture should always be the last fixture in bases of any fixture or
|
||||
test case that uses it.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
START_DATE : datetime
|
||||
@@ -420,7 +423,9 @@ class WithTradingCalendars(object):
|
||||
cls.trading_calendars[exchange] = get_calendar(cal_str)
|
||||
|
||||
|
||||
class WithTradingEnvironment(WithAssetFinder, WithTradingCalendars):
|
||||
class WithTradingEnvironment(WithAssetFinder,
|
||||
WithTradingCalendars,
|
||||
WithDefaultDateBounds):
|
||||
"""
|
||||
ZiplineTestCase mixin providing cls.env as a class-level fixture.
|
||||
|
||||
@@ -527,7 +532,7 @@ class WithSimParams(WithTradingEnvironment):
|
||||
cls.sim_params = cls.make_simparams()
|
||||
|
||||
|
||||
class WithTradingSessions(WithTradingCalendars):
|
||||
class WithTradingSessions(WithTradingCalendars, WithDefaultDateBounds):
|
||||
"""
|
||||
ZiplineTestCase mixin providing cls.trading_days, cls.all_trading_sessions
|
||||
as a class-level fixture.
|
||||
@@ -1212,7 +1217,7 @@ class WithSeededRandomPipelineEngine(WithTradingSessions, WithAssetFinder):
|
||||
if start_date not in self.trading_days:
|
||||
raise AssertionError("Start date not in calendar: %s" % start_date)
|
||||
if end_date not in self.trading_days:
|
||||
raise AssertionError("Start date not in calendar: %s" % start_date)
|
||||
raise AssertionError("End date not in calendar: %s" % end_date)
|
||||
return self.seeded_random_engine.run_pipeline(
|
||||
pipeline,
|
||||
start_date,
|
||||
|
||||
@@ -483,10 +483,17 @@ def expect_element(*_pos, **named):
|
||||
raise TypeError("expect_element() only takes keyword arguments.")
|
||||
|
||||
def _expect_element(collection):
|
||||
if isinstance(collection, (set, frozenset)):
|
||||
# Special case the error message for set and frozen set to make it
|
||||
# less verbose.
|
||||
collection_for_error_message = tuple(sorted(collection))
|
||||
else:
|
||||
collection_for_error_message = collection
|
||||
|
||||
template = (
|
||||
"%(funcname)s() expected a value in {collection} "
|
||||
"for argument '%(argname)s', but got %(actual)s instead."
|
||||
).format(collection=collection)
|
||||
).format(collection=collection_for_error_message)
|
||||
return make_check(
|
||||
ValueError,
|
||||
template,
|
||||
|
||||
@@ -11,8 +11,11 @@ from numpy import (
|
||||
broadcast,
|
||||
busday_count,
|
||||
datetime64,
|
||||
diff,
|
||||
dtype,
|
||||
empty,
|
||||
flatnonzero,
|
||||
hstack,
|
||||
nan,
|
||||
vectorize,
|
||||
where
|
||||
@@ -364,3 +367,66 @@ def vectorized_is_element(array, choices):
|
||||
Array indicating whether each element of ``array`` was in ``choices``.
|
||||
"""
|
||||
return vectorize(choices.__contains__, otypes=[bool])(array)
|
||||
|
||||
|
||||
def as_column(a):
|
||||
"""
|
||||
Convert an array of shape (N,) into an array of shape (N, 1).
|
||||
|
||||
This is equivalent to `a[:, np.newaxis]`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a : np.ndarray
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import numpy as np
|
||||
>>> a = np.arange(5)
|
||||
>>> a
|
||||
array([0, 1, 2, 3, 4])
|
||||
>>> as_column(a)
|
||||
array([[0],
|
||||
[1],
|
||||
[2],
|
||||
[3],
|
||||
[4]])
|
||||
>>> as_column(a).shape
|
||||
(5, 1)
|
||||
"""
|
||||
if a.ndim != 1:
|
||||
raise ValueError(
|
||||
"as_column expected an 1-dimensional array, "
|
||||
"but got an array of shape %s" % a.shape
|
||||
)
|
||||
return a[:, None]
|
||||
|
||||
|
||||
def changed_locations(a, include_first):
|
||||
"""
|
||||
Compute indices of values in ``a`` that differ from the previous value.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a : np.ndarray
|
||||
The array in which to find indices of change.
|
||||
include_first : bool
|
||||
Whether or not to consider the first index of the array as "changed".
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import numpy as np
|
||||
>>> changed_locations(np.array([0, 0, 5, 5, 1, 1]), include_first=False)
|
||||
array([2, 4])
|
||||
|
||||
>>> changed_locations(np.array([0, 0, 5, 5, 1, 1]), include_first=True)
|
||||
array([0, 2, 4])
|
||||
"""
|
||||
if a.ndim > 1:
|
||||
raise ValueError("indices_of_changed_values only supports 1D arrays.")
|
||||
indices = flatnonzero(diff(a)) + 1
|
||||
|
||||
if not include_first:
|
||||
return indices
|
||||
|
||||
return hstack([[0], indices])
|
||||
|
||||
@@ -96,3 +96,53 @@ def mask_between_time(dts, start, end, include_start=True, include_end=True):
|
||||
left_op(start_micros, time_micros),
|
||||
right_op(time_micros, end_micros),
|
||||
)
|
||||
|
||||
|
||||
def nearest_unequal_elements(dts, dt):
|
||||
"""
|
||||
Find values in ``dts`` closest but not equal to ``dt``.
|
||||
|
||||
Returns a pair of (last_before, first_after).
|
||||
|
||||
When ``dt`` is less than any element in ``dts``, ``last_before`` is None.
|
||||
When ``dt`` is greater than any element in ``dts``, ``first_after`` is None.
|
||||
|
||||
``dts`` must be unique and sorted in increasing order.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dts : pd.DatetimeIndex
|
||||
Dates in which to search.
|
||||
dt : pd.Timestamp
|
||||
Date for which to find bounds.
|
||||
"""
|
||||
if not dts.is_unique:
|
||||
raise ValueError("dts must be unique")
|
||||
|
||||
if not dts.is_monotonic_increasing:
|
||||
raise ValueError("dts must be sorted in increasing order")
|
||||
|
||||
if not len(dts):
|
||||
return None, None
|
||||
|
||||
sortpos = dts.searchsorted(dt, side='left')
|
||||
try:
|
||||
sortval = dts[sortpos]
|
||||
except IndexError:
|
||||
# dt is greater than any value in the array.
|
||||
return dts[-1], None
|
||||
|
||||
if dt < sortval:
|
||||
lower_ix = sortpos - 1
|
||||
upper_ix = sortpos
|
||||
elif dt == sortval:
|
||||
lower_ix = sortpos - 1
|
||||
upper_ix = sortpos + 1
|
||||
else:
|
||||
lower_ix = sortpos
|
||||
upper_ix = sortpos + 1
|
||||
|
||||
lower_value = dts[lower_ix] if lower_ix >= 0 else None
|
||||
upper_value = dts[upper_ix] if upper_ix < len(dts) else None
|
||||
|
||||
return lower_value, upper_value
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Shared docstrings for parameters that should be documented identically
|
||||
across different functions.
|
||||
"""
|
||||
import re
|
||||
from six import iteritems
|
||||
from textwrap import dedent
|
||||
|
||||
PIPELINE_DOWNSAMPLING_FREQUENCY_DOC = dedent(
|
||||
"""\
|
||||
frequency : {'year_start', 'quarter_start', 'month_start', 'week_start'}
|
||||
A string indicating desired sampling dates:
|
||||
|
||||
'year_start' -> first trading day of each year
|
||||
'quarter_start' -> first trading day of January, April, July, October
|
||||
'month_start' -> first trading day of each month
|
||||
'week_start' -> first trading_day of each week
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def pad_lines(prefix, s):
|
||||
"""Apply a prefix to each line in s."""
|
||||
return '\n'.join(prefix + line for line in s.splitlines())
|
||||
|
||||
|
||||
def format_docstring(owner_name, docstring, formatters):
|
||||
"""
|
||||
Template ``formatters`` into ``docstring``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
owner_name : str
|
||||
The name of the function or class whose docstring is being templated.
|
||||
Only used for error messages.
|
||||
docstring : str
|
||||
The docstring to template.
|
||||
formatters : dict[str -> str]
|
||||
Parameters for a str.format() call on ``docstring``.
|
||||
|
||||
Multi-line values in ``formatters`` will have leading whitespace padded
|
||||
to match the leading whitespace of the substitution string.
|
||||
"""
|
||||
# Build a dict of parameters to a vanilla format() call by searching for
|
||||
# each entry in **formatters and applying any leading whitespace to each
|
||||
# line in the desired substitution.
|
||||
format_params = {}
|
||||
for target, doc_for_target in iteritems(formatters):
|
||||
# Search for '{name}', with optional leading whitespace.
|
||||
regex = re.compile('^(\s*)' + '({' + target + '})$', re.MULTILINE)
|
||||
matches = regex.findall(docstring)
|
||||
if not matches:
|
||||
raise ValueError(
|
||||
"Couldn't find template for parameter {!r} in docstring "
|
||||
"for {}."
|
||||
"\nParameter name must be alone on a line surrounded by "
|
||||
"braces.".format(target, owner_name),
|
||||
)
|
||||
elif len(matches) > 1:
|
||||
raise ValueError(
|
||||
"Couldn't found multiple templates for parameter {!r}"
|
||||
"in docstring for {}."
|
||||
"\nParameter should only appear once.".format(
|
||||
target, owner_name
|
||||
)
|
||||
)
|
||||
|
||||
(leading_whitespace, _) = matches[0]
|
||||
format_params[target] = pad_lines(leading_whitespace, doc_for_target)
|
||||
|
||||
return docstring.format(**format_params)
|
||||
|
||||
|
||||
def templated_docstring(**docs):
|
||||
"""
|
||||
Decorator allowing the use of templated docstrings.
|
||||
|
||||
Usage
|
||||
-----
|
||||
>>> @templated_docstring(foo='bar')
|
||||
... def my_func(self, foo):
|
||||
... '''{foo}'''
|
||||
...
|
||||
>>> my_func.__doc__
|
||||
'bar'
|
||||
"""
|
||||
def decorator(f):
|
||||
f.__doc__ = format_docstring(f.__name__, f.__doc__, docs)
|
||||
return f
|
||||
return decorator
|
||||
Reference in New Issue
Block a user