mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 10:15:34 +08:00
MAINT: Refactor in prep for downsampled terms.
- Split out extra_rows handling into an `ExecutionPlan` subclass.
`ExecutionPlan` now requires the dates and calendar against which a
set of terms will be computed, and now defers to a term's
`compute_extra_rows` method when deciding how many extra rows are
required to compute for that term. This will allow downsampled terms
to request enough extra rows to guarantee that we can maintain consistent
calculation dates.
As a consequence of the above, `TermGraph` now only deals with logical
dependencies, not with metadata surrounding extra row calculations.
This means that TermGraph can be used to generate dependency
visualizations in interactive contexts where we don't yet have a
calendar or start/end dates.
- Refactored test_{filter,factor,classifier} to use check_terms instead
of run_graph. This makes it easier to make changes to TermGraph,
since the testing interface is now to simply provide a dict of terms.
- Refactored BasePipelineTestCase to use fixtures to create an asset
finder. This fixes a potential leak of the test's asset db, which was
not being explicitly cleaned up.
- Refactored test_technical to use BasePipelineTestCase.
- Added a new special term, `InputDates()`, which can be used to request
date labels for inputs. Like `AssetExists`, `InputDates` is provided
in the initial workspace by default.
- Added a default (failing) `_compute` method to `AssetExists` which
provides a more useful error than AttributeError.
This commit is contained in:
+41
-31
@@ -1,24 +1,23 @@
|
||||
"""
|
||||
Base class for Pipeline API unittests.
|
||||
Base class for Pipeline API unit tests.
|
||||
"""
|
||||
from functools import wraps
|
||||
|
||||
import numpy as np
|
||||
from numpy import arange, prod
|
||||
from pandas import date_range, Int64Index, DataFrame
|
||||
from pandas import DataFrame, Timestamp
|
||||
from six import iteritems
|
||||
|
||||
from zipline.assets.synthetic import make_simple_equity_info
|
||||
from zipline.pipeline.engine import SimplePipelineEngine
|
||||
from zipline.pipeline import TermGraph
|
||||
from zipline.pipeline.term import AssetExists
|
||||
from zipline.pipeline import ExecutionPlan
|
||||
from zipline.pipeline.term import AssetExists, InputDates
|
||||
from zipline.testing import (
|
||||
check_arrays,
|
||||
ExplodingObject,
|
||||
tmp_asset_finder,
|
||||
)
|
||||
from zipline.testing.fixtures import (
|
||||
WithTradingCalendars,
|
||||
WithAssetFinder,
|
||||
WithTradingSessions,
|
||||
ZiplineTestCase,
|
||||
)
|
||||
|
||||
@@ -53,32 +52,26 @@ def with_defaults(**default_funcs):
|
||||
with_default_shape = with_defaults(shape=lambda self: self.default_shape)
|
||||
|
||||
|
||||
class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase):
|
||||
class BasePipelineTestCase(WithTradingSessions,
|
||||
WithAssetFinder,
|
||||
ZiplineTestCase):
|
||||
START_DATE = Timestamp('2014', tz='UTC')
|
||||
END_DATE = Timestamp('2014-12-31', tz='UTC')
|
||||
ASSET_FINDER_EQUITY_SIDS = list(range(20))
|
||||
|
||||
@classmethod
|
||||
def init_class_fixtures(cls):
|
||||
super(BasePipelineTestCase, cls).init_class_fixtures()
|
||||
|
||||
cls.__calendar = date_range('2014', '2015',
|
||||
freq=cls.trading_calendar.day)
|
||||
cls.__assets = assets = Int64Index(arange(1, 20))
|
||||
cls.__tmp_finder_ctx = tmp_asset_finder(
|
||||
equities=make_simple_equity_info(
|
||||
assets,
|
||||
cls.__calendar[0],
|
||||
cls.__calendar[-1],
|
||||
)
|
||||
)
|
||||
cls.__finder = cls.__tmp_finder_ctx.__enter__()
|
||||
cls.__mask = cls.__finder.lifetimes(
|
||||
cls.__calendar[-30:],
|
||||
cls.default_asset_exists_mask = cls.asset_finder.lifetimes(
|
||||
cls.nyse_sessions[-30:],
|
||||
include_start_date=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def default_shape(self):
|
||||
"""Default shape for methods that build test data."""
|
||||
return self.__mask.shape
|
||||
return self.default_asset_exists_mask.shape
|
||||
|
||||
def run_graph(self, graph, initial_workspace, mask=None):
|
||||
"""
|
||||
@@ -103,14 +96,17 @@ class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase):
|
||||
"""
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column: ExplodingObject(),
|
||||
self.__calendar,
|
||||
self.__finder,
|
||||
self.nyse_sessions,
|
||||
self.asset_finder,
|
||||
)
|
||||
if mask is None:
|
||||
mask = self.__mask
|
||||
mask = self.default_asset_exists_mask
|
||||
|
||||
dates, assets, mask_values = explode(mask)
|
||||
|
||||
initial_workspace.setdefault(AssetExists(), mask_values)
|
||||
initial_workspace.setdefault(InputDates(), dates)
|
||||
|
||||
return engine.compute_chunk(
|
||||
graph,
|
||||
dates,
|
||||
@@ -118,15 +114,29 @@ class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase):
|
||||
initial_workspace,
|
||||
)
|
||||
|
||||
def check_terms(self, terms, expected, initial_workspace, mask):
|
||||
def check_terms(self,
|
||||
terms,
|
||||
expected,
|
||||
initial_workspace,
|
||||
mask,
|
||||
check=check_arrays):
|
||||
"""
|
||||
Compile the given terms into a TermGraph, compute it with
|
||||
initial_workspace, and compare the results with ``expected``.
|
||||
"""
|
||||
graph = TermGraph(terms)
|
||||
start_date, end_date = mask.index[[0, -1]]
|
||||
graph = ExecutionPlan(
|
||||
terms,
|
||||
all_dates=self.nyse_sessions,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
results = self.run_graph(graph, initial_workspace, mask)
|
||||
for key, (res, exp) in dzip_exact(results, expected).items():
|
||||
check_arrays(res, exp)
|
||||
check(res, exp)
|
||||
|
||||
return results
|
||||
|
||||
def build_mask(self, array):
|
||||
"""
|
||||
@@ -138,13 +148,13 @@ class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase):
|
||||
array,
|
||||
# Use the **last** N dates rather than the first N so that we have
|
||||
# space for lookbacks.
|
||||
index=self.__calendar[-ndates:],
|
||||
columns=self.__assets[:nassets],
|
||||
index=self.nyse_sessions[-ndates:],
|
||||
columns=self.ASSET_FINDER_EQUITY_SIDS[:nassets],
|
||||
dtype=bool,
|
||||
)
|
||||
|
||||
@with_default_shape
|
||||
def arange_data(self, shape, dtype=float):
|
||||
def arange_data(self, shape, dtype=np.float64):
|
||||
"""
|
||||
Build a block of testing data from numpy.arange.
|
||||
"""
|
||||
|
||||
@@ -40,6 +40,7 @@ from six import iteritems, itervalues
|
||||
from toolz import merge
|
||||
|
||||
from zipline.assets.synthetic import make_rotating_equity_info
|
||||
from zipline.errors import NoFurtherDataError
|
||||
from zipline.lib.adjustment import MULTIPLY
|
||||
from zipline.lib.labelarray import LabelArray
|
||||
from zipline.pipeline import CustomFactor, Pipeline
|
||||
@@ -65,6 +66,7 @@ from zipline.pipeline.loaders.synthetic import (
|
||||
expected_bar_values_2d,
|
||||
)
|
||||
from zipline.pipeline.sentinels import NotSpecified
|
||||
from zipline.pipeline.term import InputDates
|
||||
from zipline.testing import (
|
||||
AssetID,
|
||||
AssetIDPlusDay,
|
||||
@@ -81,7 +83,7 @@ from zipline.testing.fixtures import (
|
||||
ZiplineTestCase,
|
||||
)
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.numpy_utils import bool_dtype
|
||||
from zipline.utils.numpy_utils import bool_dtype, datetime64ns_dtype
|
||||
|
||||
|
||||
class RollingSumDifference(CustomFactor):
|
||||
@@ -229,6 +231,31 @@ class ConstantInputTestCase(WithTradingEnvironment, ZiplineTestCase):
|
||||
with self.assertRaises(NoFurtherDataError) as e:
|
||||
engine.run_pipeline(p, self.dates[8], self.dates[8])
|
||||
|
||||
def test_input_dates_provided_by_default(self):
|
||||
loader = self.loader
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column: loader, self.dates, self.asset_finder,
|
||||
)
|
||||
|
||||
class TestFactor(CustomFactor):
|
||||
inputs = [InputDates(), USEquityPricing.close]
|
||||
window_length = 10
|
||||
dtype = datetime64ns_dtype
|
||||
|
||||
def compute(self, today, assets, out, dates, closes):
|
||||
first, last = dates[[0, -1], 0]
|
||||
assert last == today.asm8
|
||||
assert len(dates) == len(closes) == self.window_length
|
||||
out[:] = first
|
||||
|
||||
p = Pipeline(columns={'t': TestFactor()})
|
||||
results = engine.run_pipeline(p, self.dates[9], self.dates[10])
|
||||
|
||||
# All results are the same, so just grab one column.
|
||||
column = results.unstack().iloc[:, 0].values
|
||||
check_arrays(column, self.dates[:2].values)
|
||||
|
||||
|
||||
def test_same_day_pipeline(self):
|
||||
loader = self.loader
|
||||
engine = SimplePipelineEngine(
|
||||
|
||||
@@ -123,20 +123,18 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
data = arange(25).reshape(5, 5)
|
||||
data[eye(5, dtype=bool)] = custom_missing_value
|
||||
|
||||
graph = TermGraph(
|
||||
self.check_terms(
|
||||
{
|
||||
'isnull': factor.isnull(),
|
||||
'notnull': factor.notnull(),
|
||||
}
|
||||
)
|
||||
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
},
|
||||
{
|
||||
'isnull': eye(5, dtype=bool),
|
||||
'notnull': ~eye(5, dtype=bool),
|
||||
},
|
||||
initial_workspace={factor: data},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
check_arrays(results['isnull'], eye(5, dtype=bool))
|
||||
check_arrays(results['notnull'], ~eye(5, dtype=bool))
|
||||
|
||||
def test_isnull_datetime_dtype(self):
|
||||
class DatetimeFactor(Factor):
|
||||
@@ -149,20 +147,18 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
data = arange(25).reshape(5, 5).astype('datetime64[ns]')
|
||||
data[eye(5, dtype=bool)] = NaTns
|
||||
|
||||
graph = TermGraph(
|
||||
self.check_terms(
|
||||
{
|
||||
'isnull': factor.isnull(),
|
||||
'notnull': factor.notnull(),
|
||||
}
|
||||
)
|
||||
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
},
|
||||
{
|
||||
'isnull': eye(5, dtype=bool),
|
||||
'notnull': ~eye(5, dtype=bool),
|
||||
},
|
||||
initial_workspace={factor: data},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
check_arrays(results['isnull'], eye(5, dtype=bool))
|
||||
check_arrays(results['notnull'], ~eye(5, dtype=bool))
|
||||
|
||||
@for_each_factor_dtype
|
||||
def test_rank_ascending(self, name, factor_dtype):
|
||||
@@ -206,14 +202,12 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
}
|
||||
|
||||
def check(terms):
|
||||
graph = TermGraph(terms)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected={name: expected_ranks[name] for name in terms},
|
||||
initial_workspace={f: data},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
for method in terms:
|
||||
check_arrays(results[method], expected_ranks[method])
|
||||
|
||||
check({meth: f.rank(method=meth) for meth in expected_ranks})
|
||||
check({
|
||||
@@ -265,14 +259,12 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
}
|
||||
|
||||
def check(terms):
|
||||
graph = TermGraph(terms)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected={name: expected_ranks[name] for name in terms},
|
||||
initial_workspace={f: data},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
for method in terms:
|
||||
check_arrays(results[method], expected_ranks[method])
|
||||
|
||||
check({
|
||||
meth: f.rank(method=meth, ascending=False)
|
||||
@@ -294,14 +286,12 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
mask_data = ~eye(5, dtype=bool)
|
||||
initial_workspace = {f: data, Mask(): mask_data}
|
||||
|
||||
graph = TermGraph(
|
||||
{
|
||||
"ascending_nomask": f.rank(ascending=True),
|
||||
"ascending_mask": f.rank(ascending=True, mask=Mask()),
|
||||
"descending_nomask": f.rank(ascending=False),
|
||||
"descending_mask": f.rank(ascending=False, mask=Mask()),
|
||||
}
|
||||
)
|
||||
terms = {
|
||||
"ascending_nomask": f.rank(ascending=True),
|
||||
"ascending_mask": f.rank(ascending=True, mask=Mask()),
|
||||
"descending_nomask": f.rank(ascending=False),
|
||||
"descending_mask": f.rank(ascending=False, mask=Mask()),
|
||||
}
|
||||
|
||||
expected = {
|
||||
"ascending_nomask": array([[1., 3., 4., 5., 2.],
|
||||
@@ -328,13 +318,12 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
[4., 3., 2., 1., nan]]),
|
||||
}
|
||||
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected,
|
||||
initial_workspace,
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
for method in results:
|
||||
check_arrays(expected[method], results[method])
|
||||
|
||||
@for_each_factor_dtype
|
||||
def test_grouped_rank_ascending(self, name, factor_dtype=float64_dtype):
|
||||
@@ -363,7 +352,7 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
missing_value=None,
|
||||
)
|
||||
|
||||
expected_grouped_ranks = {
|
||||
expected_ranks = {
|
||||
'ordinal': array(
|
||||
[[1., 1., 3., 2., 2.],
|
||||
[1., 2., 3., 1., 2.],
|
||||
@@ -402,9 +391,9 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
}
|
||||
|
||||
def check(terms):
|
||||
graph = TermGraph(terms)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected={name: expected_ranks[name] for name in terms},
|
||||
initial_workspace={
|
||||
f: data,
|
||||
c: classifier_data,
|
||||
@@ -413,25 +402,22 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
|
||||
for method in terms:
|
||||
check_arrays(results[method], expected_grouped_ranks[method])
|
||||
|
||||
# Not specifying the value of ascending param should default to True
|
||||
check({
|
||||
meth: f.rank(method=meth, groupby=c)
|
||||
for meth in expected_grouped_ranks
|
||||
for meth in expected_ranks
|
||||
})
|
||||
check({
|
||||
meth: f.rank(method=meth, groupby=str_c)
|
||||
for meth in expected_grouped_ranks
|
||||
for meth in expected_ranks
|
||||
})
|
||||
check({
|
||||
meth: f.rank(method=meth, groupby=c, ascending=True)
|
||||
for meth in expected_grouped_ranks
|
||||
for meth in expected_ranks
|
||||
})
|
||||
check({
|
||||
meth: f.rank(method=meth, groupby=str_c, ascending=True)
|
||||
for meth in expected_grouped_ranks
|
||||
for meth in expected_ranks
|
||||
})
|
||||
|
||||
# Not passing a method should default to ordinal
|
||||
@@ -468,7 +454,7 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
missing_value=None,
|
||||
)
|
||||
|
||||
expected_grouped_ranks = {
|
||||
expected_ranks = {
|
||||
'ordinal': array(
|
||||
[[2., 2., 1., 1., 3.],
|
||||
[2., 1., 1., 2., 3.],
|
||||
@@ -507,9 +493,9 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
}
|
||||
|
||||
def check(terms):
|
||||
graph = TermGraph(terms)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected={name: expected_ranks[name] for name in terms},
|
||||
initial_workspace={
|
||||
f: data,
|
||||
c: classifier_data,
|
||||
@@ -518,16 +504,13 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
|
||||
for method in terms:
|
||||
check_arrays(results[method], expected_grouped_ranks[method])
|
||||
|
||||
check({
|
||||
meth: f.rank(method=meth, groupby=c, ascending=False)
|
||||
for meth in expected_grouped_ranks
|
||||
for meth in expected_ranks
|
||||
})
|
||||
check({
|
||||
meth: f.rank(method=meth, groupby=str_c, ascending=False)
|
||||
for meth in expected_grouped_ranks
|
||||
for meth in expected_ranks
|
||||
})
|
||||
|
||||
# Not passing a method should default to ordinal
|
||||
@@ -707,9 +690,9 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
expected['grouped_str'] = expected['grouped']
|
||||
expected['grouped_masked_str'] = expected['grouped_masked']
|
||||
|
||||
graph = TermGraph(terms)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected,
|
||||
initial_workspace={
|
||||
f: factor_data,
|
||||
c: classifier_data,
|
||||
@@ -717,20 +700,13 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
m: filter_data,
|
||||
},
|
||||
mask=self.build_mask(self.ones_mask(shape=factor_data.shape)),
|
||||
# The hand-computed values aren't very precise (in particular,
|
||||
# we truncate repeating decimals at 3 places). This is just
|
||||
# asserting that the example isn't misleading by being totally
|
||||
# wrong.
|
||||
check=partial(check_allclose, atol=0.001),
|
||||
)
|
||||
|
||||
for key, (res, exp) in dzip_exact(results, expected).items():
|
||||
check_allclose(
|
||||
res,
|
||||
exp,
|
||||
# The hand-computed values aren't very precise (in particular,
|
||||
# we truncate repeating decimals at 3 places) This is just
|
||||
# asserting that the example isn't misleading by being totally
|
||||
# wrong.
|
||||
atol=0.001,
|
||||
err_msg="Mismatch for %r" % key
|
||||
)
|
||||
|
||||
@parameter_space(
|
||||
seed_value=range(1, 2),
|
||||
normalizer_name_and_func=[
|
||||
|
||||
+153
-176
@@ -12,7 +12,6 @@ from numpy import (
|
||||
array,
|
||||
eye,
|
||||
float64,
|
||||
full_like,
|
||||
full,
|
||||
inf,
|
||||
isfinite,
|
||||
@@ -127,7 +126,6 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
nan_data[:, 0] = nan
|
||||
|
||||
mask = Mask()
|
||||
workspace = {self.f: data, mask: mask_data}
|
||||
|
||||
methods = ['top', 'bottom']
|
||||
counts = 2, 3, 10
|
||||
@@ -136,18 +134,6 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
def termname(method, count, masked):
|
||||
return '_'.join([method, str(count), 'mask' if masked else ''])
|
||||
|
||||
# Add a term for each permutation of top/bottom, count, and
|
||||
# mask/no_mask.
|
||||
terms = {}
|
||||
for method, count, masked in term_combos:
|
||||
kwargs = {'N': count}
|
||||
if masked:
|
||||
kwargs['mask'] = mask
|
||||
term = getattr(self.f, method)(**kwargs)
|
||||
terms[termname(method, count, masked)] = term
|
||||
|
||||
results = self.run_graph(TermGraph(terms), initial_workspace=workspace)
|
||||
|
||||
def expected_result(method, count, masked):
|
||||
# Ranking with a mask is equivalent to ranking with nans applied on
|
||||
# the masked values.
|
||||
@@ -158,72 +144,55 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
elif method == 'bottom':
|
||||
return rowwise_rank(to_rank) < count
|
||||
|
||||
# Add a term for each permutation of top/bottom, count, and
|
||||
# mask/no_mask.
|
||||
terms = {}
|
||||
expected = {}
|
||||
for method, count, masked in term_combos:
|
||||
result = results[termname(method, count, masked)]
|
||||
kwargs = {'N': count}
|
||||
if masked:
|
||||
kwargs['mask'] = mask
|
||||
term = getattr(self.f, method)(**kwargs)
|
||||
name = termname(method, count, masked)
|
||||
terms[name] = term
|
||||
expected[name] = expected_result(method, count, masked)
|
||||
|
||||
# Check that `min(c, num_assets)` assets passed each day.
|
||||
passed_per_day = result.sum(axis=1)
|
||||
check_arrays(
|
||||
passed_per_day,
|
||||
full_like(passed_per_day, min(count, data.shape[1])),
|
||||
)
|
||||
|
||||
expected = expected_result(method, count, masked)
|
||||
check_arrays(result, expected)
|
||||
|
||||
def test_bottom(self):
|
||||
counts = 2, 3, 10
|
||||
data = self.randn_data(seed=5) # Arbitrary seed choice.
|
||||
results = self.run_graph(
|
||||
TermGraph(
|
||||
{'bottom_' + str(c): self.f.bottom(c) for c in counts}
|
||||
),
|
||||
initial_workspace={self.f: data},
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected,
|
||||
initial_workspace={self.f: data, mask: mask_data},
|
||||
mask=self.build_mask(self.ones_mask()),
|
||||
)
|
||||
for c in counts:
|
||||
result = results['bottom_' + str(c)]
|
||||
|
||||
# Check that `min(c, num_assets)` assets passed each day.
|
||||
passed_per_day = result.sum(axis=1)
|
||||
check_arrays(
|
||||
passed_per_day,
|
||||
full_like(passed_per_day, min(c, data.shape[1])),
|
||||
)
|
||||
|
||||
# Check that the bottom `c` assets passed.
|
||||
expected = rowwise_rank(data) < c
|
||||
check_arrays(result, expected)
|
||||
|
||||
def test_percentile_between(self):
|
||||
|
||||
quintiles = range(5)
|
||||
filter_names = ['pct_' + str(q) for q in quintiles]
|
||||
iter_quintiles = zip(filter_names, quintiles)
|
||||
|
||||
graph = TermGraph(
|
||||
{
|
||||
name: self.f.percentile_between(q * 20.0, (q + 1) * 20.0)
|
||||
for name, q in zip(filter_names, quintiles)
|
||||
}
|
||||
)
|
||||
terms = {
|
||||
name: self.f.percentile_between(q * 20.0, (q + 1) * 20.0)
|
||||
for name, q in iter_quintiles
|
||||
}
|
||||
|
||||
# Test with 5 columns and no NaNs.
|
||||
eye5 = eye(5, dtype=float64)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
initial_workspace={self.f: eye5},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
expected = {}
|
||||
for name, quintile in iter_quintiles:
|
||||
result = results[name]
|
||||
if quintile < 4:
|
||||
# There are four 0s and one 1 in each row, so the first 4
|
||||
# quintiles should be all the locations with zeros in the input
|
||||
# array.
|
||||
check_arrays(result, ~eye5.astype(bool))
|
||||
expected[name] = ~eye5.astype(bool)
|
||||
else:
|
||||
# The top quintile should match the sole 1 in each row.
|
||||
check_arrays(result, eye5.astype(bool))
|
||||
expected[name] = eye5.astype(bool)
|
||||
|
||||
self.check_terms(
|
||||
terms=terms,
|
||||
expected=expected,
|
||||
initial_workspace={self.f: eye5},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
|
||||
# Test with 6 columns, no NaNs, and one masked entry per day.
|
||||
eye6 = eye(6, dtype=float64)
|
||||
@@ -233,41 +202,44 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
[1, 1, 0, 1, 1, 1],
|
||||
[1, 1, 1, 0, 1, 1],
|
||||
[1, 1, 1, 1, 0, 1]], dtype=bool)
|
||||
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
initial_workspace={self.f: eye6},
|
||||
mask=self.build_mask(mask)
|
||||
)
|
||||
expected = {}
|
||||
for name, quintile in iter_quintiles:
|
||||
result = results[name]
|
||||
if quintile < 4:
|
||||
# Should keep all values that were 0 in the base data and were
|
||||
# 1 in the mask.
|
||||
check_arrays(result, mask & (~eye6.astype(bool))),
|
||||
expected[name] = mask & ~eye6.astype(bool)
|
||||
else:
|
||||
# Should keep all the 1s in the base data.
|
||||
check_arrays(result, eye6.astype(bool))
|
||||
# The top quintile should match the sole 1 in each row.
|
||||
expected[name] = eye6.astype(bool)
|
||||
|
||||
self.check_terms(
|
||||
terms=terms,
|
||||
expected=expected,
|
||||
initial_workspace={self.f: eye6},
|
||||
mask=self.build_mask(mask),
|
||||
)
|
||||
|
||||
# Test with 6 columns, no mask, and one NaN per day. Should have the
|
||||
# same outcome as if we had masked the NaNs.
|
||||
# In particular, the NaNs should never pass any filters.
|
||||
eye6_withnans = eye6.copy()
|
||||
putmask(eye6_withnans, ~mask, nan)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
initial_workspace={self.f: eye6},
|
||||
mask=self.build_mask(mask)
|
||||
)
|
||||
expected = {}
|
||||
for name, quintile in iter_quintiles:
|
||||
result = results[name]
|
||||
if quintile < 4:
|
||||
# Should keep all values that were 0 in the base data and were
|
||||
# 1 in the mask.
|
||||
check_arrays(result, mask & (~eye6.astype(bool))),
|
||||
expected[name] = mask & (~eye6.astype(bool))
|
||||
else:
|
||||
# Should keep all the 1s in the base data.
|
||||
check_arrays(result, eye6.astype(bool))
|
||||
expected[name] = eye6.astype(bool)
|
||||
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected,
|
||||
initial_workspace={self.f: eye6},
|
||||
mask=self.build_mask(mask),
|
||||
)
|
||||
|
||||
def test_percentile_nasty_partitions(self):
|
||||
# Test percentile with nasty partitions: divide up 5 assets into
|
||||
@@ -281,27 +253,26 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
quartiles = range(4)
|
||||
filter_names = ['pct_' + str(q) for q in quartiles]
|
||||
|
||||
graph = TermGraph(
|
||||
{
|
||||
name: self.f.percentile_between(q * 25.0, (q + 1) * 25.0)
|
||||
for name, q in zip(filter_names, quartiles)
|
||||
}
|
||||
)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
initial_workspace={self.f: data},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
terms = {
|
||||
name: self.f.percentile_between(q * 25.0, (q + 1) * 25.0)
|
||||
for name, q in zip(filter_names, quartiles)
|
||||
}
|
||||
|
||||
expected = {}
|
||||
for name, quartile in zip(filter_names, quartiles):
|
||||
result = results[name]
|
||||
lower = quartile * 25.0
|
||||
upper = (quartile + 1) * 25.0
|
||||
expected = and_(
|
||||
expected[name] = and_(
|
||||
nanpercentile(data, lower, axis=1, keepdims=True) <= data,
|
||||
data <= nanpercentile(data, upper, axis=1, keepdims=True),
|
||||
)
|
||||
check_arrays(result, expected)
|
||||
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected,
|
||||
initial_workspace={self.f: data},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
|
||||
def test_percentile_after_mask(self):
|
||||
f_input = eye(5)
|
||||
@@ -312,77 +283,79 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
without_mask = self.g.percentile_between(80, 100)
|
||||
with_mask = self.g.percentile_between(80, 100, mask=custom_mask)
|
||||
|
||||
graph = TermGraph(
|
||||
{
|
||||
'custom_mask': custom_mask,
|
||||
'without': without_mask,
|
||||
'with': with_mask,
|
||||
}
|
||||
)
|
||||
terms = {
|
||||
'mask': custom_mask,
|
||||
'without_mask': without_mask,
|
||||
'with_mask': with_mask,
|
||||
}
|
||||
expected = {
|
||||
# Mask that accepts everything except the diagonal.
|
||||
'mask': ~eye(5, dtype=bool),
|
||||
# Second should pass the largest value each day. Each row is
|
||||
# strictly increasing, so we always select the last value.
|
||||
'without_mask': array(
|
||||
[[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1]],
|
||||
dtype=bool,
|
||||
),
|
||||
# With a mask, we should remove the diagonal as an option before
|
||||
# computing percentiles. On the last day, we should get the
|
||||
# second-largest value, rather than the largest.
|
||||
'with_mask': array(
|
||||
[[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 1, 0]], # Different from without_mask!
|
||||
dtype=bool,
|
||||
),
|
||||
}
|
||||
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected,
|
||||
initial_workspace={self.f: f_input, self.g: g_input},
|
||||
mask=initial_mask,
|
||||
)
|
||||
|
||||
# First should pass everything but the diagonal.
|
||||
check_arrays(results['custom_mask'], ~eye(5, dtype=bool))
|
||||
|
||||
# Second should pass the largest value each day. Each row is strictly
|
||||
# increasing, so we always select the last value.
|
||||
expected_without = array(
|
||||
[[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1]],
|
||||
dtype=bool,
|
||||
)
|
||||
check_arrays(results['without'], expected_without)
|
||||
|
||||
# When sequencing, we should remove the diagonal as an option before
|
||||
# computing percentiles. On the last day, we should get the
|
||||
# second-largest value, rather than the largest.
|
||||
expected_with = array(
|
||||
[[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 1, 0]], # Different from previous!
|
||||
dtype=bool,
|
||||
)
|
||||
check_arrays(results['with'], expected_with)
|
||||
|
||||
def test_isnan(self):
|
||||
data = self.randn_data(seed=10)
|
||||
diag = eye(*data.shape, dtype=bool)
|
||||
data[diag] = nan
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({
|
||||
self.check_terms(
|
||||
terms={
|
||||
'isnan': self.f.isnan(),
|
||||
'isnull': self.f.isnull(),
|
||||
}),
|
||||
},
|
||||
expected={
|
||||
'isnan': diag,
|
||||
'isnull': diag,
|
||||
},
|
||||
initial_workspace={self.f: data},
|
||||
mask=self.build_mask(self.ones_mask()),
|
||||
)
|
||||
check_arrays(results['isnan'], diag)
|
||||
check_arrays(results['isnull'], diag)
|
||||
|
||||
def test_notnan(self):
|
||||
data = self.randn_data(seed=10)
|
||||
diag = eye(*data.shape, dtype=bool)
|
||||
data[diag] = nan
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({
|
||||
self.check_terms(
|
||||
terms={
|
||||
'notnan': self.f.notnan(),
|
||||
'notnull': self.f.notnull(),
|
||||
}),
|
||||
},
|
||||
expected={
|
||||
'notnan': ~diag,
|
||||
'notnull': ~diag,
|
||||
},
|
||||
initial_workspace={self.f: data},
|
||||
mask=self.build_mask(self.ones_mask()),
|
||||
)
|
||||
check_arrays(results['notnan'], ~diag)
|
||||
check_arrays(results['notnull'], ~diag)
|
||||
|
||||
def test_isfinite(self):
|
||||
data = self.randn_data(seed=10)
|
||||
@@ -390,11 +363,12 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
data[:, 2] = inf
|
||||
data[:, 4] = -inf
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({'isfinite': self.f.isfinite()}),
|
||||
self.check_terms(
|
||||
terms={'isfinite': self.f.isfinite()},
|
||||
expected={'isfinite': isfinite(data)},
|
||||
initial_workspace={self.f: data},
|
||||
mask=self.build_mask(self.ones_mask()),
|
||||
)
|
||||
check_arrays(results['isfinite'], isfinite(data))
|
||||
|
||||
def test_all(self):
|
||||
|
||||
@@ -427,18 +401,19 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({
|
||||
self.check_terms(
|
||||
terms={
|
||||
'3': All(inputs=[Input()], window_length=3),
|
||||
'4': All(inputs=[Input()], window_length=4),
|
||||
}),
|
||||
},
|
||||
expected={
|
||||
'3': expected_3,
|
||||
'4': expected_4,
|
||||
},
|
||||
initial_workspace={Input(): data},
|
||||
mask=self.build_mask(ones(shape=data.shape)),
|
||||
)
|
||||
|
||||
check_arrays(results['3'], expected_3)
|
||||
check_arrays(results['4'], expected_4)
|
||||
|
||||
def test_any(self):
|
||||
|
||||
# FUN FACT: The inputs and outputs here are exactly the negation of
|
||||
@@ -486,18 +461,19 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({
|
||||
self.check_terms(
|
||||
terms={
|
||||
'3': Any(inputs=[Input()], window_length=3),
|
||||
'4': Any(inputs=[Input()], window_length=4),
|
||||
}),
|
||||
},
|
||||
expected={
|
||||
'3': expected_3,
|
||||
'4': expected_4,
|
||||
},
|
||||
initial_workspace={Input(): data},
|
||||
mask=self.build_mask(ones(shape=data.shape)),
|
||||
)
|
||||
|
||||
check_arrays(results['3'], expected_3)
|
||||
check_arrays(results['4'], expected_4)
|
||||
|
||||
def test_at_least_N(self):
|
||||
|
||||
# With a window_length of K, AtLeastN should return 1
|
||||
@@ -553,26 +529,27 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
window_length=4,
|
||||
N=4)
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({
|
||||
self.check_terms(
|
||||
terms={
|
||||
'AllButOne': all_but_one,
|
||||
'AllButTwo': all_but_two,
|
||||
'AnyEquiv': any_equiv,
|
||||
'AllEquiv': all_equiv,
|
||||
'Any': Any(inputs=[Input()], window_length=4),
|
||||
'All': All(inputs=[Input()], window_length=4)
|
||||
}),
|
||||
},
|
||||
expected={
|
||||
'Any': expected_1,
|
||||
'AnyEquiv': expected_1,
|
||||
'AllButTwo': expected_2,
|
||||
'AllButOne': expected_3,
|
||||
'All': expected_4,
|
||||
'AllEquiv': expected_4,
|
||||
},
|
||||
initial_workspace={Input(): data},
|
||||
mask=self.build_mask(ones(shape=data.shape)),
|
||||
)
|
||||
|
||||
check_arrays(results['Any'], expected_1)
|
||||
check_arrays(results['AnyEquiv'], expected_1)
|
||||
check_arrays(results['AllButTwo'], expected_2)
|
||||
check_arrays(results['AllButOne'], expected_3)
|
||||
check_arrays(results['All'], expected_4)
|
||||
check_arrays(results['AllEquiv'], expected_4)
|
||||
|
||||
@parameter_space(factor_len=[2, 3, 4])
|
||||
def test_window_safe(self, factor_len):
|
||||
# all true data set of (days, securities)
|
||||
@@ -591,19 +568,19 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
# sum for each column
|
||||
out[:] = np_sum(filter_, axis=0)
|
||||
|
||||
results = self.run_graph(
|
||||
TermGraph({'windowsafe': TestFactor()}),
|
||||
initial_workspace={InputFilter(): data},
|
||||
)
|
||||
|
||||
# number of days in default_shape
|
||||
n = self.default_shape[0]
|
||||
|
||||
# shape of output array
|
||||
output_shape = ((n - factor_len + 1), self.default_shape[1])
|
||||
check_arrays(
|
||||
results['windowsafe'],
|
||||
full(output_shape, factor_len, dtype=float64)
|
||||
full(output_shape, factor_len, dtype=float64)
|
||||
|
||||
self.check_terms(
|
||||
terms={
|
||||
'windowsafe': TestFactor(),
|
||||
},
|
||||
expected={
|
||||
'windowsafe': full(output_shape, factor_len, dtype=float64),
|
||||
},
|
||||
initial_workspace={InputFilter(): data},
|
||||
mask=self.build_mask(self.ones_mask()),
|
||||
)
|
||||
|
||||
@parameter_space(
|
||||
|
||||
@@ -7,10 +7,7 @@ import pandas as pd
|
||||
import talib
|
||||
|
||||
from zipline.lib.adjusted_array import AdjustedArray
|
||||
from zipline.pipeline import TermGraph
|
||||
from zipline.pipeline.data import USEquityPricing
|
||||
from zipline.pipeline.engine import SimplePipelineEngine
|
||||
from zipline.pipeline.term import AssetExists
|
||||
from zipline.pipeline.factors import (
|
||||
BollingerBands,
|
||||
Aroon,
|
||||
@@ -20,61 +17,22 @@ from zipline.pipeline.factors import (
|
||||
RateOfChangePercentage,
|
||||
TrueRange,
|
||||
)
|
||||
from zipline.testing import ExplodingObject, parameter_space
|
||||
from zipline.testing.fixtures import WithAssetFinder, ZiplineTestCase
|
||||
from zipline.testing import parameter_space
|
||||
from zipline.testing.fixtures import ZiplineTestCase
|
||||
from zipline.testing.predicates import assert_equal
|
||||
|
||||
|
||||
class WithTechnicalFactor(WithAssetFinder):
|
||||
"""ZiplineTestCase fixture for testing technical factors.
|
||||
"""
|
||||
ASSET_FINDER_EQUITY_SIDS = tuple(range(5))
|
||||
START_DATE = pd.Timestamp('2014-01-01', tz='utc')
|
||||
|
||||
@classmethod
|
||||
def init_class_fixtures(cls):
|
||||
super(WithTechnicalFactor, cls).init_class_fixtures()
|
||||
cls.ndays = ndays = 24
|
||||
cls.nassets = nassets = len(cls.ASSET_FINDER_EQUITY_SIDS)
|
||||
cls.dates = dates = pd.date_range(cls.START_DATE, periods=ndays)
|
||||
cls.assets = pd.Index(cls.asset_finder.sids)
|
||||
cls.engine = SimplePipelineEngine(
|
||||
lambda column: ExplodingObject(),
|
||||
dates,
|
||||
cls.asset_finder,
|
||||
)
|
||||
cls.asset_exists = exists = np.full((ndays, nassets), True, dtype=bool)
|
||||
cls.asset_exists_masked = masked = exists.copy()
|
||||
masked[:, -1] = False
|
||||
|
||||
def run_graph(self, graph, initial_workspace, mask_sid):
|
||||
initial_workspace.setdefault(
|
||||
AssetExists(),
|
||||
self.asset_exists_masked if mask_sid else self.asset_exists,
|
||||
)
|
||||
return self.engine.compute_chunk(
|
||||
graph,
|
||||
self.dates,
|
||||
self.assets,
|
||||
initial_workspace,
|
||||
)
|
||||
from .base import BasePipelineTestCase
|
||||
|
||||
|
||||
class BollingerBandsTestCase(WithTechnicalFactor, ZiplineTestCase):
|
||||
@classmethod
|
||||
def init_class_fixtures(cls):
|
||||
super(BollingerBandsTestCase, cls).init_class_fixtures()
|
||||
cls._closes = closes = (
|
||||
np.arange(cls.ndays, dtype=float)[:, np.newaxis] +
|
||||
np.arange(cls.nassets, dtype=float) * 100
|
||||
)
|
||||
cls._closes_masked = masked = closes.copy()
|
||||
masked[:, -1] = np.nan
|
||||
class BollingerBandsTestCase(BasePipelineTestCase):
|
||||
|
||||
def closes(self, masked):
|
||||
return self._closes_masked if masked else self._closes
|
||||
def closes(self, mask_last_sid):
|
||||
data = self.arange_data(dtype=np.float64)
|
||||
if mask_last_sid:
|
||||
data[:, -1] = np.nan
|
||||
return data
|
||||
|
||||
def expected(self, window_length, k, closes):
|
||||
def expected_bbands(self, window_length, k, closes):
|
||||
"""Compute the expected data (without adjustments) for the given
|
||||
window, k, and closes array.
|
||||
|
||||
@@ -83,11 +41,14 @@ class BollingerBandsTestCase(WithTechnicalFactor, ZiplineTestCase):
|
||||
lower_cols = []
|
||||
middle_cols = []
|
||||
upper_cols = []
|
||||
for n in range(self.nassets):
|
||||
|
||||
ndates, nassets = closes.shape
|
||||
|
||||
for n in range(nassets):
|
||||
close_col = closes[:, n]
|
||||
if np.isnan(close_col).all():
|
||||
# ta-lib doesn't deal well with all nans.
|
||||
upper, middle, lower = [np.full(self.ndays, np.nan)] * 3
|
||||
upper, middle, lower = [np.full(ndates, np.nan)] * 3
|
||||
else:
|
||||
upper, middle, lower = talib.BBANDS(
|
||||
close_col,
|
||||
@@ -112,38 +73,38 @@ class BollingerBandsTestCase(WithTechnicalFactor, ZiplineTestCase):
|
||||
@parameter_space(
|
||||
window_length={5, 10, 20},
|
||||
k={1.5, 2, 2.5},
|
||||
mask_sid={True, False},
|
||||
mask_last_sid={True, False},
|
||||
__fail_fast=True,
|
||||
)
|
||||
def test_bollinger_bands(self, window_length, k, mask_sid):
|
||||
closes = self.closes(mask_sid)
|
||||
result = self.run_graph(
|
||||
TermGraph({
|
||||
'f': BollingerBands(
|
||||
window_length=window_length,
|
||||
k=k,
|
||||
),
|
||||
}),
|
||||
def test_bollinger_bands(self, window_length, k, mask_last_sid):
|
||||
closes = self.closes(mask_last_sid=mask_last_sid)
|
||||
mask = ~np.isnan(closes)
|
||||
bbands = BollingerBands(window_length=window_length, k=k)
|
||||
|
||||
expected = self.expected_bbands(window_length, k, closes)
|
||||
|
||||
self.check_terms(
|
||||
terms={
|
||||
'upper': bbands.upper,
|
||||
'middle': bbands.middle,
|
||||
'lower': bbands.lower,
|
||||
},
|
||||
expected={
|
||||
'upper': expected[0],
|
||||
'middle': expected[1],
|
||||
'lower': expected[2],
|
||||
},
|
||||
initial_workspace={
|
||||
USEquityPricing.close: AdjustedArray(
|
||||
closes,
|
||||
np.full_like(closes, True, dtype=bool),
|
||||
{},
|
||||
np.nan,
|
||||
data=closes,
|
||||
mask=mask,
|
||||
adjustments={},
|
||||
missing_value=np.nan,
|
||||
),
|
||||
},
|
||||
mask_sid=mask_sid,
|
||||
)['f']
|
||||
|
||||
expected_upper, expected_middle, expected_lower = self.expected(
|
||||
window_length,
|
||||
k,
|
||||
closes,
|
||||
mask=self.build_mask(mask),
|
||||
)
|
||||
|
||||
assert_equal(result.upper, expected_upper)
|
||||
assert_equal(result.middle, expected_middle)
|
||||
assert_equal(result.lower, expected_lower)
|
||||
|
||||
def test_bollinger_bands_output_ordering(self):
|
||||
bbands = BollingerBands(window_length=5, k=2)
|
||||
lower, middle, upper = bbands
|
||||
@@ -185,7 +146,7 @@ class AroonTestCase(ZiplineTestCase):
|
||||
assert_equal(out, expected_out)
|
||||
|
||||
|
||||
class TestFastStochasticOscillator(WithTechnicalFactor, ZiplineTestCase):
|
||||
class TestFastStochasticOscillator(ZiplineTestCase):
|
||||
"""
|
||||
Test the Fast Stochastic Oscillator
|
||||
"""
|
||||
@@ -427,7 +388,7 @@ class TestLinearWeightedMovingAverage(ZiplineTestCase):
|
||||
assert_equal(out, np.array([30., 31., 32., 33., 34.]))
|
||||
|
||||
|
||||
class TestTrueRange(WithTechnicalFactor, ZiplineTestCase):
|
||||
class TestTrueRange(ZiplineTestCase):
|
||||
|
||||
def test_tr_basic(self):
|
||||
tr = TrueRange()
|
||||
|
||||
@@ -6,6 +6,7 @@ from itertools import product
|
||||
from unittest import TestCase
|
||||
|
||||
from toolz import assoc
|
||||
import pandas as pd
|
||||
|
||||
from zipline.assets import Asset
|
||||
from zipline.errors import (
|
||||
@@ -24,7 +25,7 @@ from zipline.pipeline import (
|
||||
CustomFactor,
|
||||
Factor,
|
||||
Filter,
|
||||
TermGraph,
|
||||
ExecutionPlan,
|
||||
)
|
||||
from zipline.pipeline.data import Column, DataSet
|
||||
from zipline.pipeline.data.testing import TestingDataSet
|
||||
@@ -33,6 +34,7 @@ from zipline.pipeline.factors import RecarrayField
|
||||
from zipline.pipeline.sentinels import NotSpecified
|
||||
from zipline.pipeline.term import AssetExists, Slice
|
||||
from zipline.testing import parameter_space
|
||||
from zipline.testing.fixtures import WithTradingSessions, ZiplineTestCase
|
||||
from zipline.testing.predicates import (
|
||||
assert_equal,
|
||||
assert_raises,
|
||||
@@ -152,7 +154,14 @@ def to_dict(l):
|
||||
return dict(zip(map(str, range(len(l))), l))
|
||||
|
||||
|
||||
class DependencyResolutionTestCase(TestCase):
|
||||
class DependencyResolutionTestCase(WithTradingSessions, ZiplineTestCase):
|
||||
|
||||
TRADING_CALENDAR_STRS = ('NYSE',)
|
||||
START_DATE = pd.Timestamp('2014-01-02', tz='UTC')
|
||||
END_DATE = pd.Timestamp('2014-12-31', tz='UTC')
|
||||
|
||||
execution_plan_start = pd.Timestamp('2014-06-01', tz='UTC')
|
||||
execution_plan_end = pd.Timestamp('2014-06-30', tz='UTC')
|
||||
|
||||
def check_dependency_order(self, ordered_terms):
|
||||
seen = set()
|
||||
@@ -163,6 +172,14 @@ class DependencyResolutionTestCase(TestCase):
|
||||
|
||||
seen.add(term)
|
||||
|
||||
def make_execution_plan(self, terms):
|
||||
return ExecutionPlan(
|
||||
terms,
|
||||
self.nyse_sessions,
|
||||
self.execution_plan_start,
|
||||
self.execution_plan_end,
|
||||
)
|
||||
|
||||
def test_single_factor(self):
|
||||
"""
|
||||
Test dependency resolution for a single factor.
|
||||
@@ -182,7 +199,7 @@ class DependencyResolutionTestCase(TestCase):
|
||||
self.assertEqual(graph.node[SomeDataSet.bar]['extra_rows'], 4)
|
||||
|
||||
for foobar in gen_equivalent_factors():
|
||||
check_output(TermGraph(to_dict([foobar])))
|
||||
check_output(self.make_execution_plan(to_dict([foobar])))
|
||||
|
||||
def test_single_factor_instance_args(self):
|
||||
"""
|
||||
@@ -190,7 +207,9 @@ class DependencyResolutionTestCase(TestCase):
|
||||
the constructor.
|
||||
"""
|
||||
bar, buzz = SomeDataSet.bar, SomeDataSet.buzz
|
||||
graph = TermGraph(to_dict([SomeFactor([bar, buzz], window_length=5)]))
|
||||
|
||||
factor = SomeFactor([bar, buzz], window_length=5)
|
||||
graph = self.make_execution_plan(to_dict([factor]))
|
||||
|
||||
resolution_order = list(graph.ordered())
|
||||
|
||||
@@ -214,7 +233,7 @@ class DependencyResolutionTestCase(TestCase):
|
||||
f1 = SomeFactor([SomeDataSet.foo, SomeDataSet.bar])
|
||||
f2 = SomeOtherFactor([SomeDataSet.bar, SomeDataSet.buzz])
|
||||
|
||||
graph = TermGraph(to_dict([f1, f2]))
|
||||
graph = self.make_execution_plan(to_dict([f1, f2]))
|
||||
resolution_order = list(graph.ordered())
|
||||
|
||||
# bar should only appear once.
|
||||
|
||||
@@ -6,7 +6,7 @@ from .engine import SimplePipelineEngine
|
||||
from .factors import Factor, CustomFactor
|
||||
from .filters import Filter, CustomFilter
|
||||
from .term import Term
|
||||
from .graph import TermGraph
|
||||
from .graph import ExecutionPlan, TermGraph
|
||||
from .pipeline import Pipeline
|
||||
from .loaders import USEquityPricingLoader
|
||||
|
||||
@@ -56,6 +56,7 @@ __all__ = (
|
||||
'CustomFilter',
|
||||
'CustomClassifier',
|
||||
'engine_from_files',
|
||||
'ExecutionPlan',
|
||||
'Factor',
|
||||
'Filter',
|
||||
'Pipeline',
|
||||
|
||||
@@ -21,7 +21,7 @@ from zipline.errors import NoFurtherDataError
|
||||
from zipline.utils.numpy_utils import repeat_first_axis, repeat_last_axis
|
||||
from zipline.utils.pandas_utils import explode
|
||||
|
||||
from .term import AssetExists, LoadableTerm
|
||||
from .term import AssetExists, InputDates, LoadableTerm
|
||||
|
||||
|
||||
class PipelineEngine(with_metaclass(ABCMeta)):
|
||||
@@ -98,6 +98,7 @@ class SimplePipelineEngine(object):
|
||||
'_calendar',
|
||||
'_finder',
|
||||
'_root_mask_term',
|
||||
'_root_mask_dates_term',
|
||||
'__weakref__',
|
||||
)
|
||||
|
||||
@@ -105,7 +106,9 @@ class SimplePipelineEngine(object):
|
||||
self._get_loader = get_loader
|
||||
self._calendar = calendar
|
||||
self._finder = asset_finder
|
||||
|
||||
self._root_mask_term = AssetExists()
|
||||
self._root_mask_dates_term = InputDates()
|
||||
|
||||
def run_pipeline(self, pipeline, start_date, end_date):
|
||||
"""
|
||||
@@ -161,7 +164,13 @@ class SimplePipelineEngine(object):
|
||||
)
|
||||
|
||||
screen_name = uuid4().hex
|
||||
graph = pipeline.to_graph(screen_name, self._root_mask_term)
|
||||
graph = pipeline.to_execution_plan(
|
||||
screen_name,
|
||||
self._root_mask_term,
|
||||
self._calendar,
|
||||
start_date,
|
||||
end_date,
|
||||
)
|
||||
extra_rows = graph.extra_rows[self._root_mask_term]
|
||||
root_mask = self._compute_root_mask(start_date, end_date, extra_rows)
|
||||
dates, assets, root_mask_values = explode(root_mask)
|
||||
@@ -170,7 +179,10 @@ class SimplePipelineEngine(object):
|
||||
graph,
|
||||
dates,
|
||||
assets,
|
||||
initial_workspace={self._root_mask_term: root_mask_values},
|
||||
initial_workspace={
|
||||
self._root_mask_term: root_mask_values,
|
||||
self._root_mask_dates_term: dates.values[:, None],
|
||||
},
|
||||
)
|
||||
|
||||
return self._to_narrow(
|
||||
|
||||
+158
-76
@@ -5,7 +5,7 @@ from networkx import (
|
||||
DiGraph,
|
||||
topological_sort,
|
||||
)
|
||||
from six import itervalues, iteritems
|
||||
from six import iteritems
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.pipeline.visualize import display_graph
|
||||
|
||||
@@ -18,7 +18,112 @@ class CyclicDependency(Exception):
|
||||
|
||||
class TermGraph(DiGraph):
|
||||
"""
|
||||
Graph representation of Pipeline Term dependencies.
|
||||
An abstract representation of Pipeline Term dependencies.
|
||||
|
||||
This class does not keep any additional metadata about any term relations
|
||||
other than dependency ordering. As such it is only useful in contexts
|
||||
where you care exclusively about order properties (for example, when
|
||||
drawing visualizations of execution order).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
terms : dict
|
||||
A dict mapping names to final output terms.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
outputs
|
||||
|
||||
Methods
|
||||
-------
|
||||
ordered()
|
||||
Return a topologically-sorted iterator over the terms in self.
|
||||
|
||||
See Also
|
||||
--------
|
||||
ExecutionPlan
|
||||
"""
|
||||
def __init__(self, terms):
|
||||
super(TermGraph, self).__init__()
|
||||
|
||||
self._frozen = False
|
||||
parents = set()
|
||||
for name, term in iteritems(terms):
|
||||
self._add_to_graph(term, parents)
|
||||
# No parents should be left between top-level terms.
|
||||
assert not parents
|
||||
|
||||
self._outputs = terms
|
||||
self._ordered = topological_sort(self)
|
||||
|
||||
# Mark that no more terms should be added to the graph.
|
||||
self._frozen = True
|
||||
|
||||
def _add_to_graph(self, term, parents):
|
||||
"""
|
||||
Add a term and all its children to ``graph``.
|
||||
|
||||
``parents`` is the set of all the parents of ``term`` that we've added
|
||||
so far. It is only used to detect dependency cycles.
|
||||
"""
|
||||
if self._frozen:
|
||||
raise ValueError(
|
||||
"Can't mutate %s after construction." % type(self).__name__
|
||||
)
|
||||
|
||||
# If we've seen this node already as a parent of the current traversal,
|
||||
# it means we have an unsatisfiable dependency. This should only be
|
||||
# possible if the term's inputs are mutated after construction.
|
||||
if term in parents:
|
||||
raise CyclicDependency(term)
|
||||
|
||||
parents.add(term)
|
||||
|
||||
self.add_node(term)
|
||||
|
||||
for dependency in term.dependencies:
|
||||
self._add_to_graph(dependency, parents)
|
||||
self.add_edge(dependency, term)
|
||||
|
||||
parents.remove(term)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
"""
|
||||
Dict mapping names to designated output terms.
|
||||
"""
|
||||
return self._outputs
|
||||
|
||||
def ordered(self):
|
||||
"""
|
||||
Return a topologically-sorted iterator over the terms in `self`.
|
||||
"""
|
||||
return iter(self._ordered)
|
||||
|
||||
@lazyval
|
||||
def loadable_terms(self):
|
||||
return tuple(term for term in self if isinstance(term, LoadableTerm))
|
||||
|
||||
@lazyval
|
||||
def jpeg(self):
|
||||
return display_graph(self, 'jpeg')
|
||||
|
||||
@lazyval
|
||||
def png(self):
|
||||
return display_graph(self, 'png')
|
||||
|
||||
@lazyval
|
||||
def svg(self):
|
||||
return display_graph(self, 'svg')
|
||||
|
||||
def _repr_png_(self):
|
||||
return self.png.data
|
||||
|
||||
|
||||
class ExecutionPlan(TermGraph):
|
||||
"""
|
||||
Graph representation of Pipeline Term dependencies that includes metadata
|
||||
about extra rows required to perform computations.
|
||||
|
||||
Each node in the graph has an `extra_rows` attribute, indicating how many,
|
||||
if any, extra rows we should compute for the node. Extra rows are most
|
||||
@@ -30,6 +135,8 @@ class TermGraph(DiGraph):
|
||||
----------
|
||||
terms : dict
|
||||
A dict mapping names to final output terms.
|
||||
all_dates : pd.DatetimeIndex
|
||||
A calendar of dates to use to calculate starts and ends for each term.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -42,21 +149,58 @@ class TermGraph(DiGraph):
|
||||
ordered()
|
||||
Return a topologically-sorted iterator over the terms in self.
|
||||
"""
|
||||
def __init__(self, terms):
|
||||
super(TermGraph, self).__init__(self)
|
||||
def __init__(self,
|
||||
terms,
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date,
|
||||
min_extra_rows=0):
|
||||
super(ExecutionPlan, self).__init__(terms)
|
||||
|
||||
self._frozen = False
|
||||
parents = set()
|
||||
for term in itervalues(terms):
|
||||
self._add_to_graph(term, parents, extra_rows=0)
|
||||
# No parents should be left between top-level terms.
|
||||
assert not parents
|
||||
for term in terms.values():
|
||||
self.set_extra_rows(
|
||||
term,
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date,
|
||||
min_extra_rows=min_extra_rows,
|
||||
)
|
||||
|
||||
self._outputs = terms
|
||||
self._ordered = topological_sort(self)
|
||||
def set_extra_rows(self,
|
||||
term,
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date,
|
||||
min_extra_rows):
|
||||
"""
|
||||
Compute ``extra_rows`` for ``term`` and its transitive dependencies.
|
||||
"""
|
||||
# A term can require that additional extra rows beyond the minimum be
|
||||
# computed. This is most often used with downsampled terms, which need
|
||||
# to ensure that the first date is a computation date.
|
||||
extra_rows_for_term = term.compute_extra_rows(
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date,
|
||||
min_extra_rows,
|
||||
)
|
||||
if extra_rows_for_term < min_extra_rows:
|
||||
raise ValueError(
|
||||
"term %s requested fewer rows than the minimum of %d" % (
|
||||
term, min_extra_rows,
|
||||
)
|
||||
)
|
||||
|
||||
# Mark that no more terms should be added to the graph.
|
||||
self._frozen = True
|
||||
self._ensure_extra_rows(term, extra_rows_for_term)
|
||||
|
||||
for dependency, additional_extra_rows in term.dependencies.items():
|
||||
self.set_extra_rows(
|
||||
dependency,
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date,
|
||||
min_extra_rows=extra_rows_for_term + additional_extra_rows,
|
||||
)
|
||||
|
||||
@lazyval
|
||||
def offset(self):
|
||||
@@ -175,71 +319,9 @@ class TermGraph(DiGraph):
|
||||
for term, attrs in iteritems(self.node)
|
||||
}
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
"""
|
||||
Dict mapping names to designated output terms.
|
||||
"""
|
||||
return self._outputs
|
||||
|
||||
def ordered(self):
|
||||
"""
|
||||
Return a topologically-sorted iterator over the terms in `self`.
|
||||
"""
|
||||
return iter(self._ordered)
|
||||
|
||||
@lazyval
|
||||
def loadable_terms(self):
|
||||
return tuple(term for term in self if isinstance(term, LoadableTerm))
|
||||
|
||||
def _add_to_graph(self, term, parents, extra_rows):
|
||||
"""
|
||||
Add `term` and all its inputs to the graph.
|
||||
"""
|
||||
if self._frozen:
|
||||
raise ValueError("Can't mutate `TermGraph` after construction.")
|
||||
# If we've seen this node already as a parent of the current traversal,
|
||||
# it means we have an unsatisfiable dependency. This should only be
|
||||
# possible if the term's inputs are mutated after construction.
|
||||
if term in parents:
|
||||
raise CyclicDependency(term)
|
||||
parents.add(term)
|
||||
|
||||
# Idempotent if term is already in the graph.
|
||||
self.add_node(term)
|
||||
|
||||
# Make sure we're going to compute at least `extra_rows` of `term`.
|
||||
self._ensure_extra_rows(term, extra_rows)
|
||||
|
||||
# Recursively add dependencies.
|
||||
for dependency, additional_extra_rows in term.dependencies.items():
|
||||
self._add_to_graph(
|
||||
dependency,
|
||||
parents,
|
||||
extra_rows=extra_rows + additional_extra_rows,
|
||||
)
|
||||
self.add_edge(dependency, term)
|
||||
|
||||
parents.remove(term)
|
||||
|
||||
def _ensure_extra_rows(self, term, N):
|
||||
"""
|
||||
Ensure that we're going to compute at least N extra rows of `term`.
|
||||
"""
|
||||
attrs = self.node[term]
|
||||
attrs['extra_rows'] = max(N, attrs.get('extra_rows', 0))
|
||||
|
||||
@lazyval
|
||||
def jpeg(self):
|
||||
return display_graph(self, 'jpeg')
|
||||
|
||||
@lazyval
|
||||
def png(self):
|
||||
return display_graph(self, 'png')
|
||||
|
||||
@lazyval
|
||||
def svg(self):
|
||||
return display_graph(self, 'svg')
|
||||
|
||||
def _repr_png_(self):
|
||||
return self.png.data
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
from zipline.errors import UnsupportedPipelineOutput
|
||||
from zipline.utils.input_validation import expect_types, optional
|
||||
|
||||
from .term import AssetExists, ComputableTerm, Term
|
||||
from .graph import ExecutionPlan, TermGraph
|
||||
from .filters import Filter
|
||||
from .graph import TermGraph
|
||||
from .term import AssetExists, ComputableTerm, Term
|
||||
|
||||
|
||||
class Pipeline(object):
|
||||
@@ -148,9 +148,39 @@ class Pipeline(object):
|
||||
)
|
||||
self._screen = screen
|
||||
|
||||
def to_graph(self, screen_name, default_screen):
|
||||
def to_execution_plan(self,
|
||||
screen_name,
|
||||
default_screen,
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date):
|
||||
"""
|
||||
Compile into a TermGraph.
|
||||
Compile into an ExecutionPlan.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
screen_name : str
|
||||
Name to supply for self.screen.
|
||||
default_screen : zipline.pipeline.term.Term
|
||||
Term to use as a screen if self.screen is None.
|
||||
all_dates : pd.DatetimeIndex
|
||||
A calendar of dates to use to calculate starts and ends for each
|
||||
term.
|
||||
start_date : pd.Timestamp
|
||||
The first date of requested output.
|
||||
end_date : pd.Timestamp
|
||||
The last date of requested output.
|
||||
"""
|
||||
return ExecutionPlan(
|
||||
self._prepare_graph_terms(screen_name, default_screen),
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date,
|
||||
)
|
||||
|
||||
def to_simple_graph(self, screen_name, default_screen):
|
||||
"""
|
||||
Compile into a simple TermGraph with no extra row metadata.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -159,13 +189,16 @@ class Pipeline(object):
|
||||
default_screen : zipline.pipeline.term.Term
|
||||
Term to use as a screen if self.screen is None.
|
||||
"""
|
||||
return TermGraph(self._prepare_graph_terms())
|
||||
|
||||
def _prepare_graph_terms(self, screen_name, default_screen):
|
||||
"""Helper for to_simple_graph and to_execution_plan."""
|
||||
columns = self.columns.copy()
|
||||
screen = self.screen
|
||||
if screen is None:
|
||||
screen = default_screen
|
||||
columns[screen_name] = screen
|
||||
|
||||
return TermGraph(columns)
|
||||
return columns
|
||||
|
||||
def show_graph(self, format='svg'):
|
||||
"""
|
||||
@@ -176,7 +209,7 @@ class Pipeline(object):
|
||||
format : {'svg', 'png', 'jpeg'}
|
||||
Image format to render with. Default is 'svg'.
|
||||
"""
|
||||
g = self.to_graph('', AssetExists())
|
||||
g = self.to_simple_graph('', AssetExists())
|
||||
if format == 'svg':
|
||||
return g.svg
|
||||
elif format == 'png':
|
||||
|
||||
@@ -34,10 +34,10 @@ from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
categorical_dtype,
|
||||
datetime64ns_dtype,
|
||||
default_missing_value_for_dtype,
|
||||
)
|
||||
|
||||
from .mixins import SingleInputMixin
|
||||
from .sentinels import NotSpecified
|
||||
|
||||
|
||||
@@ -279,6 +279,20 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
# call super().
|
||||
self._subclass_called_super_validate = True
|
||||
|
||||
def compute_extra_rows(self,
|
||||
all_dates,
|
||||
start_date,
|
||||
end_date,
|
||||
min_extra_rows):
|
||||
"""
|
||||
Calculate the number of extra rows needed to compute ``self``.
|
||||
|
||||
Must return at least ``min_extra_rows``, but can optionally require
|
||||
more. This is used by downsampled terms to ensure that the first date
|
||||
computed is a recomputation date.
|
||||
"""
|
||||
return min_extra_rows
|
||||
|
||||
@abstractproperty
|
||||
def inputs(self):
|
||||
"""
|
||||
@@ -337,6 +351,38 @@ class AssetExists(Term):
|
||||
def __repr__(self):
|
||||
return "AssetExists()"
|
||||
|
||||
def _compute(self, today, assets, out):
|
||||
raise NotImplementedError(
|
||||
"AssetExists cannot be computed directly."
|
||||
" Check your PipelineEngine configuration."
|
||||
)
|
||||
|
||||
|
||||
class InputDates(Term):
|
||||
"""
|
||||
1-Dimensional term providing date labels for other term inputs.
|
||||
|
||||
This term is guaranteed to be available as an input for any term computed
|
||||
by SimplePipelineEngine.run_pipeline().
|
||||
"""
|
||||
ndim = 1
|
||||
dataset = None
|
||||
dtype = datetime64ns_dtype
|
||||
inputs = ()
|
||||
dependencies = {}
|
||||
mask = None
|
||||
windowed = False
|
||||
window_safe = True
|
||||
|
||||
def __repr__(self):
|
||||
return "InputDates()"
|
||||
|
||||
def _compute(self, today, assets, out):
|
||||
raise NotImplementedError(
|
||||
"InputDates cannot be computed directly."
|
||||
" Check your PipelineEngine configuration."
|
||||
)
|
||||
|
||||
|
||||
class LoadableTerm(Term):
|
||||
"""
|
||||
@@ -530,7 +576,7 @@ class ComputableTerm(Term):
|
||||
)
|
||||
|
||||
|
||||
class Slice(ComputableTerm, SingleInputMixin):
|
||||
class Slice(ComputableTerm):
|
||||
"""
|
||||
Term for extracting a single column of another term's output.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user