ENH: Make .latest return a Filter on bool columns.

This commit is contained in:
Scott Sanderson
2016-01-18 15:19:48 -05:00
parent d5bd2a9fb8
commit 28fdecc98b
5 changed files with 96 additions and 3 deletions
+13 -1
View File
@@ -12,11 +12,12 @@ from zipline.errors import (
TermInputsNotSpecified,
WindowLengthNotSpecified,
)
from zipline.pipeline import Factor, TermGraph
from zipline.pipeline import Factor, Filter, TermGraph
from zipline.pipeline.data import Column, DataSet
from zipline.pipeline.term import AssetExists, NotSpecified
from zipline.pipeline.expression import NUMEXPR_MATH_FUNCS
from zipline.utils.numpy_utils import (
bool_dtype,
datetime64ns_dtype,
float64_dtype,
)
@@ -331,6 +332,17 @@ class ObjectIdentityTestCase(TestCase):
with self.assertRaises(InvalidDType):
SomeFactor(dtype=1)
def test_latest_on_different_dtypes(self):
class D(DataSet):
bool_col = Column(dtype=bool_dtype)
float_col = Column(dtype=float64_dtype)
datetime_col = Column(dtype=datetime64ns_dtype)
self.assertIsInstance(D.bool_col.latest, Filter)
self.assertIsInstance(D.float_col.latest, Factor)
self.assertIsInstance(D.datetime_col.latest, Factor)
class SubDataSetTestCase(TestCase):
def test_subdataset(self):
+5 -1
View File
@@ -9,6 +9,7 @@ from six import (
from zipline.pipeline.term import Term, AssetExists
from zipline.utils.input_validation import ensure_dtype
from zipline.utils.numpy_utils import bool_dtype
from zipline.utils.preprocess import preprocess
@@ -92,7 +93,10 @@ class BoundColumn(Term):
@property
def latest(self):
from zipline.pipeline.factors import Latest
if self.dtype == bool_dtype:
from zipline.pipeline.filters import Latest
else:
from zipline.pipeline.factors import Latest
return Latest(inputs=(self,), dtype=self.dtype)
def __repr__(self):
+2
View File
@@ -1,7 +1,9 @@
from .filter import Filter, NumExprFilter, PercentileFilter
from .latest import Latest
__all__ = [
'Filter',
'Latest',
'NumExprFilter',
'PercentileFilter',
]
+47 -1
View File
@@ -252,6 +252,52 @@ class PercentileFilter(SingleInputMixin, Filter):
class CustomFilter(PositiveWindowLengthMixin, CustomTermMixin, Filter):
"""
Filter analog to ``CustomFactor``.
Base class for user-defined Filters.
Parameters
----------
inputs : iterable, optional
An iterable of `BoundColumn` instances (e.g. USEquityPricing.close),
describing the data to load and pass to `self.compute`. If this
argument is passed to the CustomFilter constructor, we look for a
class-level attribute named `inputs`.
window_length : int, optional
Number of rows to pass for each input. If this argument is not passed
to the CustomFilter constructor, we look for a class-level attribute
named `window_length`.
Notes
-----
Users implementing their own Filters should subclass CustomFilter and
implement a method named `compute` with the following signature:
.. code-block:: python
def compute(self, today, assets, out, *inputs):
...
On each simulation date, ``compute`` will be called with the current date,
an array of sids, an output array, and an input array for each expression
passed as inputs to the CustomFilter constructor.
The specific types of the values passed to `compute` are as follows::
today : np.datetime64[ns]
Row label for the last row of all arrays passed as `inputs`.
assets : np.array[int64, ndim=1]
Column labels for `out` and`inputs`.
out : np.array[bool, ndim=1]
Output array of the same shape as `assets`. `compute` should write
its desired return values into `out`.
*inputs : tuple of np.array
Raw data arrays corresponding to the values of `self.inputs`.
See the documentation for
:class:`~zipline.pipeline.factors.factor.CustomFactor` for more details on
implementing a custom ``compute`` method.
See Also
--------
zipline.pipeline.factors.factor.CustomFactor
"""
ctx = nullctx()
+29
View File
@@ -0,0 +1,29 @@
"""
Filter that produces the most most recently-known value of a boolean-valued
Column.
"""
from zipline.utils.numpy_utils import bool_dtype
from .filter import CustomFilter
from ..mixins import SingleInputMixin
class Latest(SingleInputMixin, CustomFilter):
"""
Filter producing the most recently-known value of `inputs[0]` on each day.
"""
window_length = 1
def compute(self, today, assets, out, data):
out[:] = data[-1]
def _validate(self):
if self.inputs[0].dtype != bool_dtype:
raise TypeError(
"{name} expected an input of dtype bool, "
"but got {not_bool} instead.".format(
name=type(self).__name__,
not_bool=self.inputs[0].dtype,
)
)
super(Latest, self)._validate()