diff --git a/tests/pipeline/base.py b/tests/pipeline/base.py index a5c8454f..98de5aee 100644 --- a/tests/pipeline/base.py +++ b/tests/pipeline/base.py @@ -1,24 +1,23 @@ """ -Base class for Pipeline API unittests. +Base class for Pipeline API unit tests. """ from functools import wraps import numpy as np from numpy import arange, prod -from pandas import date_range, Int64Index, DataFrame +from pandas import DataFrame, Timestamp from six import iteritems -from zipline.assets.synthetic import make_simple_equity_info from zipline.pipeline.engine import SimplePipelineEngine -from zipline.pipeline import TermGraph -from zipline.pipeline.term import AssetExists +from zipline.pipeline import ExecutionPlan +from zipline.pipeline.term import AssetExists, InputDates from zipline.testing import ( check_arrays, ExplodingObject, - tmp_asset_finder, ) from zipline.testing.fixtures import ( - WithTradingCalendars, + WithAssetFinder, + WithTradingSessions, ZiplineTestCase, ) @@ -53,32 +52,26 @@ def with_defaults(**default_funcs): with_default_shape = with_defaults(shape=lambda self: self.default_shape) -class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase): +class BasePipelineTestCase(WithTradingSessions, + WithAssetFinder, + ZiplineTestCase): + START_DATE = Timestamp('2014', tz='UTC') + END_DATE = Timestamp('2014-12-31', tz='UTC') + ASSET_FINDER_EQUITY_SIDS = list(range(20)) @classmethod def init_class_fixtures(cls): super(BasePipelineTestCase, cls).init_class_fixtures() - cls.__calendar = date_range('2014', '2015', - freq=cls.trading_calendar.day) - cls.__assets = assets = Int64Index(arange(1, 20)) - cls.__tmp_finder_ctx = tmp_asset_finder( - equities=make_simple_equity_info( - assets, - cls.__calendar[0], - cls.__calendar[-1], - ) - ) - cls.__finder = cls.__tmp_finder_ctx.__enter__() - cls.__mask = cls.__finder.lifetimes( - cls.__calendar[-30:], + cls.default_asset_exists_mask = cls.asset_finder.lifetimes( + cls.nyse_sessions[-30:], include_start_date=False, ) 
@property def default_shape(self): """Default shape for methods that build test data.""" - return self.__mask.shape + return self.default_asset_exists_mask.shape def run_graph(self, graph, initial_workspace, mask=None): """ @@ -103,14 +96,17 @@ class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase): """ engine = SimplePipelineEngine( lambda column: ExplodingObject(), - self.__calendar, - self.__finder, + self.nyse_sessions, + self.asset_finder, ) if mask is None: - mask = self.__mask + mask = self.default_asset_exists_mask dates, assets, mask_values = explode(mask) + initial_workspace.setdefault(AssetExists(), mask_values) + initial_workspace.setdefault(InputDates(), dates) + return engine.compute_chunk( graph, dates, @@ -118,15 +114,29 @@ class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase): initial_workspace, ) - def check_terms(self, terms, expected, initial_workspace, mask): + def check_terms(self, + terms, + expected, + initial_workspace, + mask, + check=check_arrays): """ Compile the given terms into a TermGraph, compute it with initial_workspace, and compare the results with ``expected``. """ - graph = TermGraph(terms) + start_date, end_date = mask.index[[0, -1]] + graph = ExecutionPlan( + terms, + all_dates=self.nyse_sessions, + start_date=start_date, + end_date=end_date, + ) + results = self.run_graph(graph, initial_workspace, mask) for key, (res, exp) in dzip_exact(results, expected).items(): - check_arrays(res, exp) + check(res, exp) + + return results def build_mask(self, array): """ @@ -138,13 +148,13 @@ class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase): array, # Use the **last** N dates rather than the first N so that we have # space for lookbacks. 
- index=self.__calendar[-ndates:], - columns=self.__assets[:nassets], + index=self.nyse_sessions[-ndates:], + columns=self.ASSET_FINDER_EQUITY_SIDS[:nassets], dtype=bool, ) @with_default_shape - def arange_data(self, shape, dtype=float): + def arange_data(self, shape, dtype=np.float64): """ Build a block of testing data from numpy.arange. """ diff --git a/tests/pipeline/test_engine.py b/tests/pipeline/test_engine.py index 0e2d2197..fecebd93 100644 --- a/tests/pipeline/test_engine.py +++ b/tests/pipeline/test_engine.py @@ -40,6 +40,7 @@ from six import iteritems, itervalues from toolz import merge from zipline.assets.synthetic import make_rotating_equity_info +from zipline.errors import NoFurtherDataError from zipline.lib.adjustment import MULTIPLY from zipline.lib.labelarray import LabelArray from zipline.pipeline import CustomFactor, Pipeline @@ -65,6 +66,7 @@ from zipline.pipeline.loaders.synthetic import ( expected_bar_values_2d, ) from zipline.pipeline.sentinels import NotSpecified +from zipline.pipeline.term import InputDates from zipline.testing import ( AssetID, AssetIDPlusDay, @@ -81,7 +83,7 @@ from zipline.testing.fixtures import ( ZiplineTestCase, ) from zipline.utils.memoize import lazyval -from zipline.utils.numpy_utils import bool_dtype +from zipline.utils.numpy_utils import bool_dtype, datetime64ns_dtype class RollingSumDifference(CustomFactor): @@ -229,6 +231,31 @@ class ConstantInputTestCase(WithTradingEnvironment, ZiplineTestCase): with self.assertRaises(NoFurtherDataError) as e: engine.run_pipeline(p, self.dates[8], self.dates[8]) + def test_input_dates_provided_by_default(self): + loader = self.loader + engine = SimplePipelineEngine( + lambda column: loader, self.dates, self.asset_finder, + ) + + class TestFactor(CustomFactor): + inputs = [InputDates(), USEquityPricing.close] + window_length = 10 + dtype = datetime64ns_dtype + + def compute(self, today, assets, out, dates, closes): + first, last = dates[[0, -1], 0] + assert last == today.asm8 + 
assert len(dates) == len(closes) == self.window_length + out[:] = first + + p = Pipeline(columns={'t': TestFactor()}) + results = engine.run_pipeline(p, self.dates[9], self.dates[10]) + + # All results are the same, so just grab one column. + column = results.unstack().iloc[:, 0].values + check_arrays(column, self.dates[:2].values) + + def test_same_day_pipeline(self): loader = self.loader engine = SimplePipelineEngine( diff --git a/tests/pipeline/test_factor.py b/tests/pipeline/test_factor.py index 4c63429d..ba50fdae 100644 --- a/tests/pipeline/test_factor.py +++ b/tests/pipeline/test_factor.py @@ -123,20 +123,18 @@ class FactorTestCase(BasePipelineTestCase): data = arange(25).reshape(5, 5) data[eye(5, dtype=bool)] = custom_missing_value - graph = TermGraph( + self.check_terms( { 'isnull': factor.isnull(), 'notnull': factor.notnull(), - } - ) - - results = self.run_graph( - graph, + }, + { + 'isnull': eye(5, dtype=bool), + 'notnull': ~eye(5, dtype=bool), + }, initial_workspace={factor: data}, mask=self.build_mask(ones((5, 5))), ) - check_arrays(results['isnull'], eye(5, dtype=bool)) - check_arrays(results['notnull'], ~eye(5, dtype=bool)) def test_isnull_datetime_dtype(self): class DatetimeFactor(Factor): @@ -149,20 +147,18 @@ class FactorTestCase(BasePipelineTestCase): data = arange(25).reshape(5, 5).astype('datetime64[ns]') data[eye(5, dtype=bool)] = NaTns - graph = TermGraph( + self.check_terms( { 'isnull': factor.isnull(), 'notnull': factor.notnull(), - } - ) - - results = self.run_graph( - graph, + }, + { + 'isnull': eye(5, dtype=bool), + 'notnull': ~eye(5, dtype=bool), + }, initial_workspace={factor: data}, mask=self.build_mask(ones((5, 5))), ) - check_arrays(results['isnull'], eye(5, dtype=bool)) - check_arrays(results['notnull'], ~eye(5, dtype=bool)) @for_each_factor_dtype def test_rank_ascending(self, name, factor_dtype): @@ -206,14 +202,12 @@ class FactorTestCase(BasePipelineTestCase): } def check(terms): - graph = TermGraph(terms) - results = 
self.run_graph( - graph, + self.check_terms( + terms, + expected={name: expected_ranks[name] for name in terms}, initial_workspace={f: data}, mask=self.build_mask(ones((5, 5))), ) - for method in terms: - check_arrays(results[method], expected_ranks[method]) check({meth: f.rank(method=meth) for meth in expected_ranks}) check({ @@ -265,14 +259,12 @@ class FactorTestCase(BasePipelineTestCase): } def check(terms): - graph = TermGraph(terms) - results = self.run_graph( - graph, + self.check_terms( + terms, + expected={name: expected_ranks[name] for name in terms}, initial_workspace={f: data}, mask=self.build_mask(ones((5, 5))), ) - for method in terms: - check_arrays(results[method], expected_ranks[method]) check({ meth: f.rank(method=meth, ascending=False) @@ -294,14 +286,12 @@ class FactorTestCase(BasePipelineTestCase): mask_data = ~eye(5, dtype=bool) initial_workspace = {f: data, Mask(): mask_data} - graph = TermGraph( - { - "ascending_nomask": f.rank(ascending=True), - "ascending_mask": f.rank(ascending=True, mask=Mask()), - "descending_nomask": f.rank(ascending=False), - "descending_mask": f.rank(ascending=False, mask=Mask()), - } - ) + terms = { + "ascending_nomask": f.rank(ascending=True), + "ascending_mask": f.rank(ascending=True, mask=Mask()), + "descending_nomask": f.rank(ascending=False), + "descending_mask": f.rank(ascending=False, mask=Mask()), + } expected = { "ascending_nomask": array([[1., 3., 4., 5., 2.], @@ -328,13 +318,12 @@ class FactorTestCase(BasePipelineTestCase): [4., 3., 2., 1., nan]]), } - results = self.run_graph( - graph, + self.check_terms( + terms, + expected, initial_workspace, mask=self.build_mask(ones((5, 5))), ) - for method in results: - check_arrays(expected[method], results[method]) @for_each_factor_dtype def test_grouped_rank_ascending(self, name, factor_dtype=float64_dtype): @@ -363,7 +352,7 @@ class FactorTestCase(BasePipelineTestCase): missing_value=None, ) - expected_grouped_ranks = { + expected_ranks = { 'ordinal': array( 
[[1., 1., 3., 2., 2.], [1., 2., 3., 1., 2.], @@ -402,9 +391,9 @@ class FactorTestCase(BasePipelineTestCase): } def check(terms): - graph = TermGraph(terms) - results = self.run_graph( - graph, + self.check_terms( + terms, + expected={name: expected_ranks[name] for name in terms}, initial_workspace={ f: data, c: classifier_data, @@ -413,25 +402,22 @@ class FactorTestCase(BasePipelineTestCase): mask=self.build_mask(ones((5, 5))), ) - for method in terms: - check_arrays(results[method], expected_grouped_ranks[method]) - # Not specifying the value of ascending param should default to True check({ meth: f.rank(method=meth, groupby=c) - for meth in expected_grouped_ranks + for meth in expected_ranks }) check({ meth: f.rank(method=meth, groupby=str_c) - for meth in expected_grouped_ranks + for meth in expected_ranks }) check({ meth: f.rank(method=meth, groupby=c, ascending=True) - for meth in expected_grouped_ranks + for meth in expected_ranks }) check({ meth: f.rank(method=meth, groupby=str_c, ascending=True) - for meth in expected_grouped_ranks + for meth in expected_ranks }) # Not passing a method should default to ordinal @@ -468,7 +454,7 @@ class FactorTestCase(BasePipelineTestCase): missing_value=None, ) - expected_grouped_ranks = { + expected_ranks = { 'ordinal': array( [[2., 2., 1., 1., 3.], [2., 1., 1., 2., 3.], @@ -507,9 +493,9 @@ class FactorTestCase(BasePipelineTestCase): } def check(terms): - graph = TermGraph(terms) - results = self.run_graph( - graph, + self.check_terms( + terms, + expected={name: expected_ranks[name] for name in terms}, initial_workspace={ f: data, c: classifier_data, @@ -518,16 +504,13 @@ class FactorTestCase(BasePipelineTestCase): mask=self.build_mask(ones((5, 5))), ) - for method in terms: - check_arrays(results[method], expected_grouped_ranks[method]) - check({ meth: f.rank(method=meth, groupby=c, ascending=False) - for meth in expected_grouped_ranks + for meth in expected_ranks }) check({ meth: f.rank(method=meth, groupby=str_c, 
ascending=False) - for meth in expected_grouped_ranks + for meth in expected_ranks }) # Not passing a method should default to ordinal @@ -707,9 +690,9 @@ class FactorTestCase(BasePipelineTestCase): expected['grouped_str'] = expected['grouped'] expected['grouped_masked_str'] = expected['grouped_masked'] - graph = TermGraph(terms) - results = self.run_graph( - graph, + self.check_terms( + terms, + expected, initial_workspace={ f: factor_data, c: classifier_data, @@ -717,20 +700,13 @@ class FactorTestCase(BasePipelineTestCase): m: filter_data, }, mask=self.build_mask(self.ones_mask(shape=factor_data.shape)), + # The hand-computed values aren't very precise (in particular, + # we truncate repeating decimals at 3 places) This is just + # asserting that the example isn't misleading by being totally + # wrong. + check=partial(check_allclose, atol=0.001), ) - for key, (res, exp) in dzip_exact(results, expected).items(): - check_allclose( - res, - exp, - # The hand-computed values aren't very precise (in particular, - # we truncate repeating decimals at 3 places) This is just - # asserting that the example isn't misleading by being totally - # wrong. - atol=0.001, - err_msg="Mismatch for %r" % key - ) - @parameter_space( seed_value=range(1, 2), normalizer_name_and_func=[ diff --git a/tests/pipeline/test_filter.py b/tests/pipeline/test_filter.py index 1861dee8..958c6e90 100644 --- a/tests/pipeline/test_filter.py +++ b/tests/pipeline/test_filter.py @@ -12,7 +12,6 @@ from numpy import ( array, eye, float64, - full_like, full, inf, isfinite, @@ -127,7 +126,6 @@ class FilterTestCase(BasePipelineTestCase): nan_data[:, 0] = nan mask = Mask() - workspace = {self.f: data, mask: mask_data} methods = ['top', 'bottom'] counts = 2, 3, 10 @@ -136,18 +134,6 @@ class FilterTestCase(BasePipelineTestCase): def termname(method, count, masked): return '_'.join([method, str(count), 'mask' if masked else '']) - # Add a term for each permutation of top/bottom, count, and - # mask/no_mask. 
- terms = {} - for method, count, masked in term_combos: - kwargs = {'N': count} - if masked: - kwargs['mask'] = mask - term = getattr(self.f, method)(**kwargs) - terms[termname(method, count, masked)] = term - - results = self.run_graph(TermGraph(terms), initial_workspace=workspace) - def expected_result(method, count, masked): # Ranking with a mask is equivalent to ranking with nans applied on # the masked values. @@ -158,72 +144,55 @@ class FilterTestCase(BasePipelineTestCase): elif method == 'bottom': return rowwise_rank(to_rank) < count + # Add a term for each permutation of top/bottom, count, and + # mask/no_mask. + terms = {} + expected = {} for method, count, masked in term_combos: - result = results[termname(method, count, masked)] + kwargs = {'N': count} + if masked: + kwargs['mask'] = mask + term = getattr(self.f, method)(**kwargs) + name = termname(method, count, masked) + terms[name] = term + expected[name] = expected_result(method, count, masked) - # Check that `min(c, num_assets)` assets passed each day. - passed_per_day = result.sum(axis=1) - check_arrays( - passed_per_day, - full_like(passed_per_day, min(count, data.shape[1])), - ) - - expected = expected_result(method, count, masked) - check_arrays(result, expected) - - def test_bottom(self): - counts = 2, 3, 10 - data = self.randn_data(seed=5) # Arbitrary seed choice. - results = self.run_graph( - TermGraph( - {'bottom_' + str(c): self.f.bottom(c) for c in counts} - ), - initial_workspace={self.f: data}, + self.check_terms( + terms, + expected, + initial_workspace={self.f: data, mask: mask_data}, + mask=self.build_mask(self.ones_mask()), ) - for c in counts: - result = results['bottom_' + str(c)] - - # Check that `min(c, num_assets)` assets passed each day. - passed_per_day = result.sum(axis=1) - check_arrays( - passed_per_day, - full_like(passed_per_day, min(c, data.shape[1])), - ) - - # Check that the bottom `c` assets passed. 
- expected = rowwise_rank(data) < c - check_arrays(result, expected) def test_percentile_between(self): quintiles = range(5) filter_names = ['pct_' + str(q) for q in quintiles] iter_quintiles = zip(filter_names, quintiles) - - graph = TermGraph( - { - name: self.f.percentile_between(q * 20.0, (q + 1) * 20.0) - for name, q in zip(filter_names, quintiles) - } - ) + terms = { + name: self.f.percentile_between(q * 20.0, (q + 1) * 20.0) + for name, q in iter_quintiles + } # Test with 5 columns and no NaNs. eye5 = eye(5, dtype=float64) - results = self.run_graph( - graph, - initial_workspace={self.f: eye5}, - mask=self.build_mask(ones((5, 5))), - ) + expected = {} for name, quintile in iter_quintiles: - result = results[name] if quintile < 4: # There are four 0s and one 1 in each row, so the first 4 # quintiles should be all the locations with zeros in the input # array. - check_arrays(result, ~eye5.astype(bool)) + expected[name] = ~eye5.astype(bool) else: # The top quintile should match the sole 1 in each row. - check_arrays(result, eye5.astype(bool)) + expected[name] = eye5.astype(bool) + + self.check_terms( + terms=terms, + expected=expected, + initial_workspace={self.f: eye5}, + mask=self.build_mask(ones((5, 5))), + ) # Test with 6 columns, no NaNs, and one masked entry per day. eye6 = eye(6, dtype=float64) @@ -233,41 +202,44 @@ class FilterTestCase(BasePipelineTestCase): [1, 1, 0, 1, 1, 1], [1, 1, 1, 0, 1, 1], [1, 1, 1, 1, 0, 1]], dtype=bool) - - results = self.run_graph( - graph, - initial_workspace={self.f: eye6}, - mask=self.build_mask(mask) - ) + expected = {} for name, quintile in iter_quintiles: - result = results[name] if quintile < 4: # Should keep all values that were 0 in the base data and were # 1 in the mask. - check_arrays(result, mask & (~eye6.astype(bool))), + expected[name] = mask & ~eye6.astype(bool) else: - # Should keep all the 1s in the base data. - check_arrays(result, eye6.astype(bool)) + # The top quintile should match the sole 1 in each row. 
+ expected[name] = eye6.astype(bool) + + self.check_terms( + terms=terms, + expected=expected, + initial_workspace={self.f: eye6}, + mask=self.build_mask(mask), + ) # Test with 6 columns, no mask, and one NaN per day. Should have the # same outcome as if we had masked the NaNs. # In particular, the NaNs should never pass any filters. eye6_withnans = eye6.copy() putmask(eye6_withnans, ~mask, nan) - results = self.run_graph( - graph, - initial_workspace={self.f: eye6}, - mask=self.build_mask(mask) - ) + expected = {} for name, quintile in iter_quintiles: - result = results[name] if quintile < 4: # Should keep all values that were 0 in the base data and were # 1 in the mask. - check_arrays(result, mask & (~eye6.astype(bool))), + expected[name] = mask & (~eye6.astype(bool)) else: # Should keep all the 1s in the base data. - check_arrays(result, eye6.astype(bool)) + expected[name] = eye6.astype(bool) + + self.check_terms( + terms, + expected, + initial_workspace={self.f: eye6}, + mask=self.build_mask(mask), + ) def test_percentile_nasty_partitions(self): # Test percentile with nasty partitions: divide up 5 assets into @@ -281,27 +253,26 @@ class FilterTestCase(BasePipelineTestCase): quartiles = range(4) filter_names = ['pct_' + str(q) for q in quartiles] - graph = TermGraph( - { - name: self.f.percentile_between(q * 25.0, (q + 1) * 25.0) - for name, q in zip(filter_names, quartiles) - } - ) - results = self.run_graph( - graph, - initial_workspace={self.f: data}, - mask=self.build_mask(ones((5, 5))), - ) + terms = { + name: self.f.percentile_between(q * 25.0, (q + 1) * 25.0) + for name, q in zip(filter_names, quartiles) + } + expected = {} for name, quartile in zip(filter_names, quartiles): - result = results[name] lower = quartile * 25.0 upper = (quartile + 1) * 25.0 - expected = and_( + expected[name] = and_( nanpercentile(data, lower, axis=1, keepdims=True) <= data, data <= nanpercentile(data, upper, axis=1, keepdims=True), ) - check_arrays(result, expected) + + 
self.check_terms( + terms, + expected, + initial_workspace={self.f: data}, + mask=self.build_mask(ones((5, 5))), + ) def test_percentile_after_mask(self): f_input = eye(5) @@ -312,77 +283,79 @@ class FilterTestCase(BasePipelineTestCase): without_mask = self.g.percentile_between(80, 100) with_mask = self.g.percentile_between(80, 100, mask=custom_mask) - graph = TermGraph( - { - 'custom_mask': custom_mask, - 'without': without_mask, - 'with': with_mask, - } - ) + terms = { + 'mask': custom_mask, + 'without_mask': without_mask, + 'with_mask': with_mask, + } + expected = { + # Mask that accepts everything except the diagonal. + 'mask': ~eye(5, dtype=bool), + # Second should pass the largest value each day. Each row is + # strictly increasing, so we always select the last value. + 'without_mask': array( + [[0, 0, 0, 0, 1], + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 1]], + dtype=bool, + ), + # With a mask, we should remove the diagonal as an option before + # computing percentiles. On the last day, we should get the + # second-largest value, rather than the largest. + 'with_mask': array( + [[0, 0, 0, 0, 1], + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 1], + [0, 0, 0, 1, 0]], # Different from with! + dtype=bool, + ), + } - results = self.run_graph( - graph, + self.check_terms( + terms, + expected, initial_workspace={self.f: f_input, self.g: g_input}, mask=initial_mask, ) - # First should pass everything but the diagonal. - check_arrays(results['custom_mask'], ~eye(5, dtype=bool)) - - # Second should pass the largest value each day. Each row is strictly - # increasing, so we always select the last value. - expected_without = array( - [[0, 0, 0, 0, 1], - [0, 0, 0, 0, 1], - [0, 0, 0, 0, 1], - [0, 0, 0, 0, 1], - [0, 0, 0, 0, 1]], - dtype=bool, - ) - check_arrays(results['without'], expected_without) - - # When sequencing, we should remove the diagonal as an option before - # computing percentiles. 
On the last day, we should get the - # second-largest value, rather than the largest. - expected_with = array( - [[0, 0, 0, 0, 1], - [0, 0, 0, 0, 1], - [0, 0, 0, 0, 1], - [0, 0, 0, 0, 1], - [0, 0, 0, 1, 0]], # Different from previous! - dtype=bool, - ) - check_arrays(results['with'], expected_with) - def test_isnan(self): data = self.randn_data(seed=10) diag = eye(*data.shape, dtype=bool) data[diag] = nan - results = self.run_graph( - TermGraph({ + self.check_terms( + terms={ 'isnan': self.f.isnan(), 'isnull': self.f.isnull(), - }), + }, + expected={ + 'isnan': diag, + 'isnull': diag, + }, initial_workspace={self.f: data}, + mask=self.build_mask(self.ones_mask()), ) - check_arrays(results['isnan'], diag) - check_arrays(results['isnull'], diag) def test_notnan(self): data = self.randn_data(seed=10) diag = eye(*data.shape, dtype=bool) data[diag] = nan - results = self.run_graph( - TermGraph({ + self.check_terms( + terms={ 'notnan': self.f.notnan(), 'notnull': self.f.notnull(), - }), + }, + expected={ + 'notnan': ~diag, + 'notnull': ~diag, + }, initial_workspace={self.f: data}, + mask=self.build_mask(self.ones_mask()), ) - check_arrays(results['notnan'], ~diag) - check_arrays(results['notnull'], ~diag) def test_isfinite(self): data = self.randn_data(seed=10) @@ -390,11 +363,12 @@ class FilterTestCase(BasePipelineTestCase): data[:, 2] = inf data[:, 4] = -inf - results = self.run_graph( - TermGraph({'isfinite': self.f.isfinite()}), + self.check_terms( + terms={'isfinite': self.f.isfinite()}, + expected={'isfinite': isfinite(data)}, initial_workspace={self.f: data}, + mask=self.build_mask(self.ones_mask()), ) - check_arrays(results['isfinite'], isfinite(data)) def test_all(self): @@ -427,18 +401,19 @@ class FilterTestCase(BasePipelineTestCase): inputs = () window_length = 0 - results = self.run_graph( - TermGraph({ + self.check_terms( + terms={ '3': All(inputs=[Input()], window_length=3), '4': All(inputs=[Input()], window_length=4), - }), + }, + expected={ + '3': 
expected_3, + '4': expected_4, + }, initial_workspace={Input(): data}, mask=self.build_mask(ones(shape=data.shape)), ) - check_arrays(results['3'], expected_3) - check_arrays(results['4'], expected_4) - def test_any(self): # FUN FACT: The inputs and outputs here are exactly the negation of @@ -486,18 +461,19 @@ class FilterTestCase(BasePipelineTestCase): inputs = () window_length = 0 - results = self.run_graph( - TermGraph({ + self.check_terms( + terms={ '3': Any(inputs=[Input()], window_length=3), '4': Any(inputs=[Input()], window_length=4), - }), + }, + expected={ + '3': expected_3, + '4': expected_4, + }, initial_workspace={Input(): data}, mask=self.build_mask(ones(shape=data.shape)), ) - check_arrays(results['3'], expected_3) - check_arrays(results['4'], expected_4) - def test_at_least_N(self): # With a window_length of K, AtLeastN should return 1 @@ -553,26 +529,27 @@ class FilterTestCase(BasePipelineTestCase): window_length=4, N=4) - results = self.run_graph( - TermGraph({ + self.check_terms( + terms={ 'AllButOne': all_but_one, 'AllButTwo': all_but_two, 'AnyEquiv': any_equiv, 'AllEquiv': all_equiv, 'Any': Any(inputs=[Input()], window_length=4), 'All': All(inputs=[Input()], window_length=4) - }), + }, + expected={ + 'Any': expected_1, + 'AnyEquiv': expected_1, + 'AllButTwo': expected_2, + 'AllButOne': expected_3, + 'All': expected_4, + 'AllEquiv': expected_4, + }, initial_workspace={Input(): data}, mask=self.build_mask(ones(shape=data.shape)), ) - check_arrays(results['Any'], expected_1) - check_arrays(results['AnyEquiv'], expected_1) - check_arrays(results['AllButTwo'], expected_2) - check_arrays(results['AllButOne'], expected_3) - check_arrays(results['All'], expected_4) - check_arrays(results['AllEquiv'], expected_4) - @parameter_space(factor_len=[2, 3, 4]) def test_window_safe(self, factor_len): # all true data set of (days, securities) @@ -591,19 +568,19 @@ class FilterTestCase(BasePipelineTestCase): # sum for each column out[:] = np_sum(filter_, axis=0) 
- results = self.run_graph( - TermGraph({'windowsafe': TestFactor()}), - initial_workspace={InputFilter(): data}, - ) - - # number of days in default_shape n = self.default_shape[0] - - # shape of output array output_shape = ((n - factor_len + 1), self.default_shape[1]) - check_arrays( - results['windowsafe'], - full(output_shape, factor_len, dtype=float64) + full(output_shape, factor_len, dtype=float64) + + self.check_terms( + terms={ + 'windowsafe': TestFactor(), + }, + expected={ + 'windowsafe': full(output_shape, factor_len, dtype=float64), + }, + initial_workspace={InputFilter(): data}, + mask=self.build_mask(self.ones_mask()), ) @parameter_space( diff --git a/tests/pipeline/test_technical.py b/tests/pipeline/test_technical.py index 718cc7fd..505e1f38 100644 --- a/tests/pipeline/test_technical.py +++ b/tests/pipeline/test_technical.py @@ -7,10 +7,7 @@ import pandas as pd import talib from zipline.lib.adjusted_array import AdjustedArray -from zipline.pipeline import TermGraph from zipline.pipeline.data import USEquityPricing -from zipline.pipeline.engine import SimplePipelineEngine -from zipline.pipeline.term import AssetExists from zipline.pipeline.factors import ( BollingerBands, Aroon, @@ -20,61 +17,22 @@ from zipline.pipeline.factors import ( RateOfChangePercentage, TrueRange, ) -from zipline.testing import ExplodingObject, parameter_space -from zipline.testing.fixtures import WithAssetFinder, ZiplineTestCase +from zipline.testing import parameter_space +from zipline.testing.fixtures import ZiplineTestCase from zipline.testing.predicates import assert_equal - -class WithTechnicalFactor(WithAssetFinder): - """ZiplineTestCase fixture for testing technical factors. 
- """ - ASSET_FINDER_EQUITY_SIDS = tuple(range(5)) - START_DATE = pd.Timestamp('2014-01-01', tz='utc') - - @classmethod - def init_class_fixtures(cls): - super(WithTechnicalFactor, cls).init_class_fixtures() - cls.ndays = ndays = 24 - cls.nassets = nassets = len(cls.ASSET_FINDER_EQUITY_SIDS) - cls.dates = dates = pd.date_range(cls.START_DATE, periods=ndays) - cls.assets = pd.Index(cls.asset_finder.sids) - cls.engine = SimplePipelineEngine( - lambda column: ExplodingObject(), - dates, - cls.asset_finder, - ) - cls.asset_exists = exists = np.full((ndays, nassets), True, dtype=bool) - cls.asset_exists_masked = masked = exists.copy() - masked[:, -1] = False - - def run_graph(self, graph, initial_workspace, mask_sid): - initial_workspace.setdefault( - AssetExists(), - self.asset_exists_masked if mask_sid else self.asset_exists, - ) - return self.engine.compute_chunk( - graph, - self.dates, - self.assets, - initial_workspace, - ) +from .base import BasePipelineTestCase -class BollingerBandsTestCase(WithTechnicalFactor, ZiplineTestCase): - @classmethod - def init_class_fixtures(cls): - super(BollingerBandsTestCase, cls).init_class_fixtures() - cls._closes = closes = ( - np.arange(cls.ndays, dtype=float)[:, np.newaxis] + - np.arange(cls.nassets, dtype=float) * 100 - ) - cls._closes_masked = masked = closes.copy() - masked[:, -1] = np.nan +class BollingerBandsTestCase(BasePipelineTestCase): - def closes(self, masked): - return self._closes_masked if masked else self._closes + def closes(self, mask_last_sid): + data = self.arange_data(dtype=np.float64) + if mask_last_sid: + data[:, -1] = np.nan + return data - def expected(self, window_length, k, closes): + def expected_bbands(self, window_length, k, closes): """Compute the expected data (without adjustments) for the given window, k, and closes array. 
@@ -83,11 +41,14 @@ class BollingerBandsTestCase(WithTechnicalFactor, ZiplineTestCase): lower_cols = [] middle_cols = [] upper_cols = [] - for n in range(self.nassets): + + ndates, nassets = closes.shape + + for n in range(nassets): close_col = closes[:, n] if np.isnan(close_col).all(): # ta-lib doesn't deal well with all nans. - upper, middle, lower = [np.full(self.ndays, np.nan)] * 3 + upper, middle, lower = [np.full(ndates, np.nan)] * 3 else: upper, middle, lower = talib.BBANDS( close_col, @@ -112,38 +73,38 @@ class BollingerBandsTestCase(WithTechnicalFactor, ZiplineTestCase): @parameter_space( window_length={5, 10, 20}, k={1.5, 2, 2.5}, - mask_sid={True, False}, + mask_last_sid={True, False}, + __fail_fast=True, ) - def test_bollinger_bands(self, window_length, k, mask_sid): - closes = self.closes(mask_sid) - result = self.run_graph( - TermGraph({ - 'f': BollingerBands( - window_length=window_length, - k=k, - ), - }), + def test_bollinger_bands(self, window_length, k, mask_last_sid): + closes = self.closes(mask_last_sid=mask_last_sid) + mask = ~np.isnan(closes) + bbands = BollingerBands(window_length=window_length, k=k) + + expected = self.expected_bbands(window_length, k, closes) + + self.check_terms( + terms={ + 'upper': bbands.upper, + 'middle': bbands.middle, + 'lower': bbands.lower, + }, + expected={ + 'upper': expected[0], + 'middle': expected[1], + 'lower': expected[2], + }, initial_workspace={ USEquityPricing.close: AdjustedArray( - closes, - np.full_like(closes, True, dtype=bool), - {}, - np.nan, + data=closes, + mask=mask, + adjustments={}, + missing_value=np.nan, ), }, - mask_sid=mask_sid, - )['f'] - - expected_upper, expected_middle, expected_lower = self.expected( - window_length, - k, - closes, + mask=self.build_mask(mask), ) - assert_equal(result.upper, expected_upper) - assert_equal(result.middle, expected_middle) - assert_equal(result.lower, expected_lower) - def test_bollinger_bands_output_ordering(self): bbands = 
BollingerBands(window_length=5, k=2) lower, middle, upper = bbands @@ -185,7 +146,7 @@ class AroonTestCase(ZiplineTestCase): assert_equal(out, expected_out) -class TestFastStochasticOscillator(WithTechnicalFactor, ZiplineTestCase): +class TestFastStochasticOscillator(ZiplineTestCase): """ Test the Fast Stochastic Oscillator """ @@ -427,7 +388,7 @@ class TestLinearWeightedMovingAverage(ZiplineTestCase): assert_equal(out, np.array([30., 31., 32., 33., 34.])) -class TestTrueRange(WithTechnicalFactor, ZiplineTestCase): +class TestTrueRange(ZiplineTestCase): def test_tr_basic(self): tr = TrueRange() diff --git a/tests/pipeline/test_term.py b/tests/pipeline/test_term.py index ecf31ce2..c58dd4e0 100644 --- a/tests/pipeline/test_term.py +++ b/tests/pipeline/test_term.py @@ -6,6 +6,7 @@ from itertools import product from unittest import TestCase from toolz import assoc +import pandas as pd from zipline.assets import Asset from zipline.errors import ( @@ -24,7 +25,7 @@ from zipline.pipeline import ( CustomFactor, Factor, Filter, - TermGraph, + ExecutionPlan, ) from zipline.pipeline.data import Column, DataSet from zipline.pipeline.data.testing import TestingDataSet @@ -33,6 +34,7 @@ from zipline.pipeline.factors import RecarrayField from zipline.pipeline.sentinels import NotSpecified from zipline.pipeline.term import AssetExists, Slice from zipline.testing import parameter_space +from zipline.testing.fixtures import WithTradingSessions, ZiplineTestCase from zipline.testing.predicates import ( assert_equal, assert_raises, @@ -152,7 +154,14 @@ def to_dict(l): return dict(zip(map(str, range(len(l))), l)) -class DependencyResolutionTestCase(TestCase): +class DependencyResolutionTestCase(WithTradingSessions, ZiplineTestCase): + + TRADING_CALENDAR_STRS = ('NYSE',) + START_DATE = pd.Timestamp('2014-01-02', tz='UTC') + END_DATE = pd.Timestamp('2014-12-31', tz='UTC') + + execution_plan_start = pd.Timestamp('2014-06-01', tz='UTC') + execution_plan_end = pd.Timestamp('2014-06-30', 
tz='UTC') def check_dependency_order(self, ordered_terms): seen = set() @@ -163,6 +172,14 @@ class DependencyResolutionTestCase(TestCase): seen.add(term) + def make_execution_plan(self, terms): + return ExecutionPlan( + terms, + self.nyse_sessions, + self.execution_plan_start, + self.execution_plan_end, + ) + def test_single_factor(self): """ Test dependency resolution for a single factor. @@ -182,7 +199,7 @@ class DependencyResolutionTestCase(TestCase): self.assertEqual(graph.node[SomeDataSet.bar]['extra_rows'], 4) for foobar in gen_equivalent_factors(): - check_output(TermGraph(to_dict([foobar]))) + check_output(self.make_execution_plan(to_dict([foobar]))) def test_single_factor_instance_args(self): """ @@ -190,7 +207,9 @@ class DependencyResolutionTestCase(TestCase): the constructor. """ bar, buzz = SomeDataSet.bar, SomeDataSet.buzz - graph = TermGraph(to_dict([SomeFactor([bar, buzz], window_length=5)])) + + factor = SomeFactor([bar, buzz], window_length=5) + graph = self.make_execution_plan(to_dict([factor])) resolution_order = list(graph.ordered()) @@ -214,7 +233,7 @@ class DependencyResolutionTestCase(TestCase): f1 = SomeFactor([SomeDataSet.foo, SomeDataSet.bar]) f2 = SomeOtherFactor([SomeDataSet.bar, SomeDataSet.buzz]) - graph = TermGraph(to_dict([f1, f2])) + graph = self.make_execution_plan(to_dict([f1, f2])) resolution_order = list(graph.ordered()) # bar should only appear once. 
diff --git a/zipline/pipeline/__init__.py b/zipline/pipeline/__init__.py index de6d142f..8d2a412f 100644 --- a/zipline/pipeline/__init__.py +++ b/zipline/pipeline/__init__.py @@ -6,7 +6,7 @@ from .engine import SimplePipelineEngine from .factors import Factor, CustomFactor from .filters import Filter, CustomFilter from .term import Term -from .graph import TermGraph +from .graph import ExecutionPlan, TermGraph from .pipeline import Pipeline from .loaders import USEquityPricingLoader @@ -56,6 +56,7 @@ __all__ = ( 'CustomFilter', 'CustomClassifier', 'engine_from_files', + 'ExecutionPlan', 'Factor', 'Filter', 'Pipeline', diff --git a/zipline/pipeline/engine.py b/zipline/pipeline/engine.py index ffb5b1c9..616178c7 100644 --- a/zipline/pipeline/engine.py +++ b/zipline/pipeline/engine.py @@ -21,7 +21,7 @@ from zipline.errors import NoFurtherDataError from zipline.utils.numpy_utils import repeat_first_axis, repeat_last_axis from zipline.utils.pandas_utils import explode -from .term import AssetExists, LoadableTerm +from .term import AssetExists, InputDates, LoadableTerm class PipelineEngine(with_metaclass(ABCMeta)): @@ -98,6 +98,7 @@ class SimplePipelineEngine(object): '_calendar', '_finder', '_root_mask_term', + '_root_mask_dates_term', '__weakref__', ) @@ -105,7 +106,9 @@ class SimplePipelineEngine(object): self._get_loader = get_loader self._calendar = calendar self._finder = asset_finder + self._root_mask_term = AssetExists() + self._root_mask_dates_term = InputDates() def run_pipeline(self, pipeline, start_date, end_date): """ @@ -161,7 +164,13 @@ class SimplePipelineEngine(object): ) screen_name = uuid4().hex - graph = pipeline.to_graph(screen_name, self._root_mask_term) + graph = pipeline.to_execution_plan( + screen_name, + self._root_mask_term, + self._calendar, + start_date, + end_date, + ) extra_rows = graph.extra_rows[self._root_mask_term] root_mask = self._compute_root_mask(start_date, end_date, extra_rows) dates, assets, root_mask_values = explode(root_mask) 
@@ -170,7 +179,10 @@ class SimplePipelineEngine(object): graph, dates, assets, - initial_workspace={self._root_mask_term: root_mask_values}, + initial_workspace={ + self._root_mask_term: root_mask_values, + self._root_mask_dates_term: dates.values[:, None], + }, ) return self._to_narrow( diff --git a/zipline/pipeline/graph.py b/zipline/pipeline/graph.py index 510776ad..eedb8c3c 100644 --- a/zipline/pipeline/graph.py +++ b/zipline/pipeline/graph.py @@ -5,7 +5,7 @@ from networkx import ( DiGraph, topological_sort, ) -from six import itervalues, iteritems +from six import iteritems from zipline.utils.memoize import lazyval from zipline.pipeline.visualize import display_graph @@ -18,7 +18,112 @@ class CyclicDependency(Exception): class TermGraph(DiGraph): """ - Graph represention of Pipeline Term dependencies. + An abstract representation of Pipeline Term dependencies. + + This class does not keep any additional metadata about any term relations + other than dependency ordering. As such it is only useful in contexts + where you care exclusively about order properties (for example, when + drawing visualizations of execution order). + + Parameters + ---------- + terms : dict + A dict mapping names to final output terms. + + Attributes + ---------- + outputs + + Methods + ------- + ordered() + Return a topologically-sorted iterator over the terms in self. + + See Also + -------- + ExecutionPlan + """ + def __init__(self, terms): + super(TermGraph, self).__init__() + + self._frozen = False + parents = set() + for name, term in iteritems(terms): + self._add_to_graph(term, parents) + # No parents should be left between top-level terms. + assert not parents + + self._outputs = terms + self._ordered = topological_sort(self) + + # Mark that no more terms should be added to the graph. + self._frozen = True + + def _add_to_graph(self, term, parents): + """ + Add a term and all its children to ``graph``. 
+ + ``parents`` is the set of all the parents of ``term`` that we've added + so far. It is only used to detect dependency cycles. + """ + if self._frozen: + raise ValueError( + "Can't mutate %s after construction." % type(self).__name__ + ) + + # If we've seen this node already as a parent of the current traversal, + # it means we have an unsatisfiable dependency. This should only be + # possible if the term's inputs are mutated after construction. + if term in parents: + raise CyclicDependency(term) + + parents.add(term) + + self.add_node(term) + + for dependency in term.dependencies: + self._add_to_graph(dependency, parents) + self.add_edge(dependency, term) + + parents.remove(term) + + @property + def outputs(self): + """ + Dict mapping names to designated output terms. + """ + return self._outputs + + def ordered(self): + """ + Return a topologically-sorted iterator over the terms in `self`. + """ + return iter(self._ordered) + + @lazyval + def loadable_terms(self): + return tuple(term for term in self if isinstance(term, LoadableTerm)) + + @lazyval + def jpeg(self): + return display_graph(self, 'jpeg') + + @lazyval + def png(self): + return display_graph(self, 'png') + + @lazyval + def svg(self): + return display_graph(self, 'svg') + + def _repr_png_(self): + return self.png.data + + +class ExecutionPlan(TermGraph): + """ + Graph representation of Pipeline Term dependencies that includes metadata + about extra rows required to perform computations. Each node in the graph has an `extra_rows` attribute, indicating how many, if any, extra rows we should compute for the node. Extra rows are most @@ -30,6 +135,8 @@ class TermGraph(DiGraph): ---------- terms : dict A dict mapping names to final output terms. + all_dates : pd.DatetimeIndex + A calendar of dates used to calculate starts and ends for each term. Attributes ---------- @@ -42,21 +149,58 @@ class TermGraph(DiGraph): ordered() Return a topologically-sorted iterator over the terms in self. 
 """ - def __init__(self, terms): - super(TermGraph, self).__init__(self) + def __init__(self, + terms, + all_dates, + start_date, + end_date, + min_extra_rows=0): + super(ExecutionPlan, self).__init__(terms) - self._frozen = False - parents = set() - for term in itervalues(terms): - self._add_to_graph(term, parents, extra_rows=0) - # No parents should be left between top-level terms. - assert not parents + for term in terms.values(): + self.set_extra_rows( + term, + all_dates, + start_date, + end_date, + min_extra_rows=min_extra_rows, + ) - self._outputs = terms - self._ordered = topological_sort(self) + def set_extra_rows(self, + term, + all_dates, + start_date, + end_date, + min_extra_rows): + """ + Compute ``extra_rows`` for ``term`` and its transitive dependencies. + """ + # A term can require that additional extra rows beyond the minimum be + # computed. This is most often used with downsampled terms, which need + # to ensure that the first date is a computation date. + extra_rows_for_term = term.compute_extra_rows( + all_dates, + start_date, + end_date, + min_extra_rows, + ) + if extra_rows_for_term < min_extra_rows: + raise ValueError( + "term %s requested fewer rows than the minimum of %d" % ( + term, min_extra_rows, + ) + ) - # Mark that no more terms should be added to the graph. - self._frozen = True + self._ensure_extra_rows(term, extra_rows_for_term) + + for dependency, additional_extra_rows in term.dependencies.items(): + self.set_extra_rows( + dependency, + all_dates, + start_date, + end_date, + min_extra_rows=extra_rows_for_term + additional_extra_rows, + ) @lazyval def offset(self): @@ -175,71 +319,9 @@ class TermGraph(DiGraph): for term, attrs in iteritems(self.node) } - @property - def outputs(self): - """ - Dict mapping names to designated output terms. - """ - return self._outputs - - def ordered(self): - """ - Return a topologically-sorted iterator over the terms in `self`. 
- """ - return iter(self._ordered) - - @lazyval - def loadable_terms(self): - return tuple(term for term in self if isinstance(term, LoadableTerm)) - - def _add_to_graph(self, term, parents, extra_rows): - """ - Add `term` and all its inputs to the graph. - """ - if self._frozen: - raise ValueError("Can't mutate `TermGraph` after construction.") - # If we've seen this node already as a parent of the current traversal, - # it means we have an unsatisifiable dependency. This should only be - # possible if the term's inputs are mutated after construction. - if term in parents: - raise CyclicDependency(term) - parents.add(term) - - # Idempotent if term is already in the graph. - self.add_node(term) - - # Make sure we're going to compute at least `extra_rows` of `term`. - self._ensure_extra_rows(term, extra_rows) - - # Recursively add dependencies. - for dependency, additional_extra_rows in term.dependencies.items(): - self._add_to_graph( - dependency, - parents, - extra_rows=extra_rows + additional_extra_rows, - ) - self.add_edge(dependency, term) - - parents.remove(term) - def _ensure_extra_rows(self, term, N): """ Ensure that we're going to compute at least N extra rows of `term`. 
""" attrs = self.node[term] attrs['extra_rows'] = max(N, attrs.get('extra_rows', 0)) - - @lazyval - def jpeg(self): - return display_graph(self, 'jpeg') - - @lazyval - def png(self): - return display_graph(self, 'png') - - @lazyval - def svg(self): - return display_graph(self, 'svg') - - def _repr_png_(self): - return self.png.data diff --git a/zipline/pipeline/pipeline.py b/zipline/pipeline/pipeline.py index 53276b82..631b99d3 100644 --- a/zipline/pipeline/pipeline.py +++ b/zipline/pipeline/pipeline.py @@ -2,9 +2,9 @@ from zipline.errors import UnsupportedPipelineOutput from zipline.utils.input_validation import expect_types, optional -from .term import AssetExists, ComputableTerm, Term +from .graph import ExecutionPlan, TermGraph from .filters import Filter -from .graph import TermGraph +from .term import AssetExists, ComputableTerm, Term class Pipeline(object): @@ -148,9 +148,39 @@ class Pipeline(object): ) self._screen = screen - def to_graph(self, screen_name, default_screen): + def to_execution_plan(self, + screen_name, + default_screen, + all_dates, + start_date, + end_date): """ - Compile into a TermGraph. + Compile into an ExecutionPlan. + + Parameters + ---------- + screen_name : str + Name to supply for self.screen. + default_screen : zipline.pipeline.term.Term + Term to use as a screen if self.screen is None. + all_dates : pd.DatetimeIndex + A calendar of dates to use to calculate starts and ends for each + term. + start_date : pd.Timestamp + The first date of requested output. + end_date : pd.Timestamp + The last date of requested output. + """ + return ExecutionPlan( + self._prepare_graph_terms(screen_name, default_screen), + all_dates, + start_date, + end_date, + ) + + def to_simple_graph(self, screen_name, default_screen): + """ + Compile into a simple TermGraph with no extra row metadata. Parameters ---------- @@ -159,13 +189,16 @@ class Pipeline(object): default_screen : zipline.pipeline.term.Term Term to use as a screen if self.screen is None. 
 """ + return TermGraph(self._prepare_graph_terms(screen_name, default_screen)) + + def _prepare_graph_terms(self, screen_name, default_screen): + """Helper for to_simple_graph and to_execution_plan.""" columns = self.columns.copy() screen = self.screen if screen is None: screen = default_screen columns[screen_name] = screen - - return TermGraph(columns) + return columns def show_graph(self, format='svg'): """ @@ -176,7 +209,7 @@ class Pipeline(object): format : {'svg', 'png', 'jpeg'} Image format to render with. Default is 'svg'. """ - g = self.to_graph('', AssetExists()) + g = self.to_simple_graph('', AssetExists()) if format == 'svg': return g.svg elif format == 'png': diff --git a/zipline/pipeline/term.py b/zipline/pipeline/term.py index 5a5dba5b..5c7218e2 100644 --- a/zipline/pipeline/term.py +++ b/zipline/pipeline/term.py @@ -34,10 +34,10 @@ from zipline.utils.memoize import lazyval from zipline.utils.numpy_utils import ( bool_dtype, categorical_dtype, + datetime64ns_dtype, default_missing_value_for_dtype, ) -from .mixins import SingleInputMixin from .sentinels import NotSpecified @@ -279,6 +279,20 @@ class Term(with_metaclass(ABCMeta, object)): # call super(). self._subclass_called_super_validate = True + def compute_extra_rows(self, + all_dates, + start_date, + end_date, + min_extra_rows): + """ + Calculate the number of extra rows needed to compute ``self``. + + Must return at least ``min_extra_rows``, but can optionally require + more. This is used by downsampled terms to ensure that the first date + computed is a recomputation date. + """ + return min_extra_rows + @abstractproperty def inputs(self): """ @@ -337,6 +351,38 @@ class AssetExists(Term): def __repr__(self): return "AssetExists()" + def _compute(self, today, assets, out): + raise NotImplementedError( + "AssetExists cannot be computed directly." + " Check your PipelineEngine configuration." + ) + + +class InputDates(Term): + """ + 1-Dimensional term providing date labels for other term inputs. 
+ + This term is guaranteed to be available as an input for any term computed + by SimplePipelineEngine.run_pipeline(). + """ + ndim = 1 + dataset = None + dtype = datetime64ns_dtype + inputs = () + dependencies = {} + mask = None + windowed = False + window_safe = True + + def __repr__(self): + return "InputDates()" + + def _compute(self, today, assets, out): + raise NotImplementedError( + "InputDates cannot be computed directly." + " Check your PipelineEngine configuration." + ) + class LoadableTerm(Term): """ @@ -530,7 +576,7 @@ class ComputableTerm(Term): ) -class Slice(ComputableTerm, SingleInputMixin): +class Slice(ComputableTerm): """ Term for extracting a single column of a another term's output.