diff --git a/tests/pipeline/test_filter.py b/tests/pipeline/test_filter.py index c4562afe..1051cee2 100644 --- a/tests/pipeline/test_filter.py +++ b/tests/pipeline/test_filter.py @@ -26,11 +26,13 @@ from numpy import ( from numpy.random import randn, seed as random_seed from zipline.errors import BadPercentileBounds -from zipline.pipeline import Filter, Factor +from zipline.pipeline import Filter, Factor, Pipeline from zipline.pipeline.classifiers import Classifier from zipline.pipeline.factors import CustomFactor -from zipline.pipeline.filters import All, Any, AtLeastN -from zipline.testing import parameter_space, permute_rows +from zipline.pipeline.filters import All, Any, AtLeastN, SpecificAssets +from zipline.testing import parameter_space, permute_rows, ZiplineTestCase +from zipline.testing.fixtures import WithSeededRandomPipelineEngine +from zipline.testing.predicates import assert_equal from zipline.utils.numpy_utils import float64_dtype, int64_dtype from .base import BasePipelineTestCase, with_default_shape @@ -820,3 +822,40 @@ class FilterTestCase(BasePipelineTestCase): }, mask=self.build_mask(permute(rot90(self.eye_mask(shape=shape)))), ) + + +class SpecificAssetsTestCase(WithSeededRandomPipelineEngine, + ZiplineTestCase): + + ASSET_FINDER_EQUITY_SIDS = tuple(range(10)) + + def test_specific_assets(self): + assets = self.asset_finder.retrieve_all(self.ASSET_FINDER_EQUITY_SIDS) + + class SidFactor(CustomFactor): + """A factor that just returns each asset's sid.""" + inputs = () + window_length = 1 + + def compute(self, today, sids, out): + out[:] = sids + + pipe = Pipeline( + columns={ + 'sid': SidFactor(), + 'evens': SpecificAssets(assets[::2]), + 'odds': SpecificAssets(assets[1::2]), + 'first_five': SpecificAssets(assets[:5]), + 'last_three': SpecificAssets(assets[-3:]), + }, + ) + + start, end = self.trading_days[[-10, -1]] + results = self.run_pipeline(pipe, start, end).unstack() + + sids = results.sid.astype(int64_dtype) + + assert_equal(results.evens, ~(sids % 2).astype(bool)) + assert_equal(results.odds, (sids % 2).astype(bool)) + assert_equal(results.first_five, sids < 5) + assert_equal(results.last_three, sids >= 7) diff --git a/zipline/pipeline/filters/__init__.py b/zipline/pipeline/filters/__init__.py index c88a2732..348714e3 100644 --- a/zipline/pipeline/filters/__init__.py +++ b/zipline/pipeline/filters/__init__.py @@ -8,6 +8,7 @@ from .filter import ( NumExprFilter, PercentileFilter, SingleAsset, + SpecificAssets, ) from .smoothing import All, Any, AtLeastN @@ -24,4 +25,5 @@ __all__ = [ 'NumExprFilter', 'PercentileFilter', 'SingleAsset', + 'SpecificAssets', ] diff --git a/zipline/pipeline/filters/filter.py b/zipline/pipeline/filters/filter.py index 13380dcd..e43c3bb3 100644 --- a/zipline/pipeline/filters/filter.py +++ b/zipline/pipeline/filters/filter.py @@ -4,11 +4,13 @@ filter.py from itertools import chain from operator import attrgetter +import numpy as np from numpy import ( float64, nan, nanpercentile, ) +import pandas as pd from zipline.errors import ( BadPercentileBounds, @@ -32,7 +34,7 @@ from zipline.pipeline.mixins import ( SingleInputMixin, ) from zipline.pipeline.term import ComputableTerm, Term -from zipline.utils.input_validation import expect_types +from zipline.utils.input_validation import coerce_types, expect_types from zipline.utils.memoize import classlazyval from zipline.utils.numpy_utils import bool_dtype, repeat_first_axis @@ -494,3 +496,22 @@ class SingleAsset(Filter): asset=self._asset, start_date=dates[0], end_date=dates[-1], ) return out + + +class SpecificAssets(Filter): + """ + A Filter that computes True for a specific set of predetermined assets. + """ + inputs = () + window_length = 0 + params = ('sids',) + + @expect_types(assets=(list, tuple, np.ndarray)) + @coerce_types(assets=((list, np.ndarray, pd.Series), list)) + def __new__(cls, assets): + sids = frozenset(asset.sid for asset in assets) + return super(SpecificAssets, cls).__new__(cls, sids=sids) + + def _compute(self, arrays, dates, sids, mask): + my_columns = sids.isin(self.params['sids']) + return repeat_first_axis(my_columns, len(mask)) & mask