From 28fdecc98bcdfa5ba5c22e7eb1e01ff815eb1426 Mon Sep 17 00:00:00 2001 From: Scott Sanderson Date: Mon, 18 Jan 2016 15:19:48 -0500 Subject: [PATCH] ENH: Make .latest return a Filter on bool columns. --- tests/pipeline/test_term.py | 14 +++++++- zipline/pipeline/data/dataset.py | 6 +++- zipline/pipeline/filters/__init__.py | 2 ++ zipline/pipeline/filters/filter.py | 48 +++++++++++++++++++++++++++- zipline/pipeline/filters/latest.py | 29 +++++++++++++++++ 5 files changed, 96 insertions(+), 3 deletions(-) create mode 100644 zipline/pipeline/filters/latest.py diff --git a/tests/pipeline/test_term.py b/tests/pipeline/test_term.py index c657811a..a16a9129 100644 --- a/tests/pipeline/test_term.py +++ b/tests/pipeline/test_term.py @@ -12,11 +12,12 @@ from zipline.errors import ( TermInputsNotSpecified, WindowLengthNotSpecified, ) -from zipline.pipeline import Factor, TermGraph +from zipline.pipeline import Factor, Filter, TermGraph from zipline.pipeline.data import Column, DataSet from zipline.pipeline.term import AssetExists, NotSpecified from zipline.pipeline.expression import NUMEXPR_MATH_FUNCS from zipline.utils.numpy_utils import ( + bool_dtype, datetime64ns_dtype, float64_dtype, ) @@ -331,6 +332,17 @@ class ObjectIdentityTestCase(TestCase): with self.assertRaises(InvalidDType): SomeFactor(dtype=1) + def test_latest_on_different_dtypes(self): + + class D(DataSet): + bool_col = Column(dtype=bool_dtype) + float_col = Column(dtype=float64_dtype) + datetime_col = Column(dtype=datetime64ns_dtype) + + self.assertIsInstance(D.bool_col.latest, Filter) + self.assertIsInstance(D.float_col.latest, Factor) + self.assertIsInstance(D.datetime_col.latest, Factor) + class SubDataSetTestCase(TestCase): def test_subdataset(self): diff --git a/zipline/pipeline/data/dataset.py b/zipline/pipeline/data/dataset.py index e65dca39..c420ea81 100644 --- a/zipline/pipeline/data/dataset.py +++ b/zipline/pipeline/data/dataset.py @@ -9,6 +9,7 @@ from six import ( from zipline.pipeline.term import 
Term, AssetExists from zipline.utils.input_validation import ensure_dtype +from zipline.utils.numpy_utils import bool_dtype from zipline.utils.preprocess import preprocess @@ -92,7 +93,10 @@ class BoundColumn(Term): @property def latest(self): - from zipline.pipeline.factors import Latest + if self.dtype == bool_dtype: + from zipline.pipeline.filters import Latest + else: + from zipline.pipeline.factors import Latest return Latest(inputs=(self,), dtype=self.dtype) def __repr__(self): diff --git a/zipline/pipeline/filters/__init__.py b/zipline/pipeline/filters/__init__.py index 3184bb9a..9ff7a0eb 100644 --- a/zipline/pipeline/filters/__init__.py +++ b/zipline/pipeline/filters/__init__.py @@ -1,7 +1,9 @@ from .filter import Filter, NumExprFilter, PercentileFilter +from .latest import Latest __all__ = [ 'Filter', + 'Latest', 'NumExprFilter', 'PercentileFilter', ] diff --git a/zipline/pipeline/filters/filter.py b/zipline/pipeline/filters/filter.py index f3c02924..2212f507 100644 --- a/zipline/pipeline/filters/filter.py +++ b/zipline/pipeline/filters/filter.py @@ -252,6 +252,52 @@ class PercentileFilter(SingleInputMixin, Filter): class CustomFilter(PositiveWindowLengthMixin, CustomTermMixin, Filter): """ - Filter analog to ``CustomFactor``. + Base class for user-defined Filters. + + Parameters + ---------- + inputs : iterable, optional + An iterable of `BoundColumn` instances (e.g. USEquityPricing.close), + describing the data to load and pass to `self.compute`. If this + argument is not passed to the CustomFilter constructor, we look for a + class-level attribute named `inputs`. + window_length : int, optional + Number of rows to pass for each input. If this argument is not passed + to the CustomFilter constructor, we look for a class-level attribute + named `window_length`. + + Notes + ----- + Users implementing their own Filters should subclass CustomFilter and + implement a method named `compute` with the following signature: + + ..
code-block:: python + + def compute(self, today, assets, out, *inputs): + ... + + On each simulation date, ``compute`` will be called with the current date, + an array of sids, an output array, and an input array for each expression + passed as inputs to the CustomFilter constructor. + + The specific types of the values passed to `compute` are as follows:: + + today : np.datetime64[ns] + Row label for the last row of all arrays passed as `inputs`. + assets : np.array[int64, ndim=1] + Column labels for `out` and `inputs`. + out : np.array[bool, ndim=1] + Output array of the same shape as `assets`. `compute` should write + its desired return values into `out`. + *inputs : tuple of np.array + Raw data arrays corresponding to the values of `self.inputs`. + + See the documentation for + :class:`~zipline.pipeline.factors.factor.CustomFactor` for more details on + implementing a custom ``compute`` method. + + See Also + -------- + zipline.pipeline.factors.factor.CustomFactor """ ctx = nullctx() diff --git a/zipline/pipeline/filters/latest.py b/zipline/pipeline/filters/latest.py new file mode 100644 index 00000000..f9588804 --- /dev/null +++ b/zipline/pipeline/filters/latest.py @@ -0,0 +1,29 @@ +""" +Filter that produces the most recently-known value of a boolean-valued +Column. +""" +from zipline.utils.numpy_utils import bool_dtype + +from .filter import CustomFilter +from ..mixins import SingleInputMixin + + +class Latest(SingleInputMixin, CustomFilter): + """ + Filter producing the most recently-known value of `inputs[0]` on each day. + """ + window_length = 1 + + def compute(self, today, assets, out, data): + out[:] = data[-1] + + def _validate(self): + if self.inputs[0].dtype != bool_dtype: + raise TypeError( + "{name} expected an input of dtype bool, " + "but got {not_bool} instead.".format( + name=type(self).__name__, + not_bool=self.inputs[0].dtype, + ) + ) + super(Latest, self)._validate()