mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 17:22:40 +08:00
ENH: Make .latest return a Filter on bool columns.
This commit is contained in:
@@ -12,11 +12,12 @@ from zipline.errors import (
|
||||
TermInputsNotSpecified,
|
||||
WindowLengthNotSpecified,
|
||||
)
|
||||
from zipline.pipeline import Factor, TermGraph
|
||||
from zipline.pipeline import Factor, Filter, TermGraph
|
||||
from zipline.pipeline.data import Column, DataSet
|
||||
from zipline.pipeline.term import AssetExists, NotSpecified
|
||||
from zipline.pipeline.expression import NUMEXPR_MATH_FUNCS
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
datetime64ns_dtype,
|
||||
float64_dtype,
|
||||
)
|
||||
@@ -331,6 +332,17 @@ class ObjectIdentityTestCase(TestCase):
|
||||
with self.assertRaises(InvalidDType):
|
||||
SomeFactor(dtype=1)
|
||||
|
||||
def test_latest_on_different_dtypes(self):
|
||||
|
||||
class D(DataSet):
|
||||
bool_col = Column(dtype=bool_dtype)
|
||||
float_col = Column(dtype=float64_dtype)
|
||||
datetime_col = Column(dtype=datetime64ns_dtype)
|
||||
|
||||
self.assertIsInstance(D.bool_col.latest, Filter)
|
||||
self.assertIsInstance(D.float_col.latest, Factor)
|
||||
self.assertIsInstance(D.datetime_col.latest, Factor)
|
||||
|
||||
|
||||
class SubDataSetTestCase(TestCase):
|
||||
def test_subdataset(self):
|
||||
|
||||
@@ -9,6 +9,7 @@ from six import (
|
||||
|
||||
from zipline.pipeline.term import Term, AssetExists
|
||||
from zipline.utils.input_validation import ensure_dtype
|
||||
from zipline.utils.numpy_utils import bool_dtype
|
||||
from zipline.utils.preprocess import preprocess
|
||||
|
||||
|
||||
@@ -92,7 +93,10 @@ class BoundColumn(Term):
|
||||
|
||||
@property
|
||||
def latest(self):
|
||||
from zipline.pipeline.factors import Latest
|
||||
if self.dtype == bool_dtype:
|
||||
from zipline.pipeline.filters import Latest
|
||||
else:
|
||||
from zipline.pipeline.factors import Latest
|
||||
return Latest(inputs=(self,), dtype=self.dtype)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from .filter import Filter, NumExprFilter, PercentileFilter
|
||||
from .latest import Latest
|
||||
|
||||
__all__ = [
|
||||
'Filter',
|
||||
'Latest',
|
||||
'NumExprFilter',
|
||||
'PercentileFilter',
|
||||
]
|
||||
|
||||
@@ -252,6 +252,52 @@ class PercentileFilter(SingleInputMixin, Filter):
|
||||
|
||||
class CustomFilter(PositiveWindowLengthMixin, CustomTermMixin, Filter):
|
||||
"""
|
||||
Filter analog to ``CustomFactor``.
|
||||
Base class for user-defined Filters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputs : iterable, optional
|
||||
An iterable of `BoundColumn` instances (e.g. USEquityPricing.close),
|
||||
describing the data to load and pass to `self.compute`. If this
|
||||
argument is passed to the CustomFilter constructor, we look for a
|
||||
class-level attribute named `inputs`.
|
||||
window_length : int, optional
|
||||
Number of rows to pass for each input. If this argument is not passed
|
||||
to the CustomFilter constructor, we look for a class-level attribute
|
||||
named `window_length`.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Users implementing their own Filters should subclass CustomFilter and
|
||||
implement a method named `compute` with the following signature:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def compute(self, today, assets, out, *inputs):
|
||||
...
|
||||
|
||||
On each simulation date, ``compute`` will be called with the current date,
|
||||
an array of sids, an output array, and an input array for each expression
|
||||
passed as inputs to the CustomFilter constructor.
|
||||
|
||||
The specific types of the values passed to `compute` are as follows::
|
||||
|
||||
today : np.datetime64[ns]
|
||||
Row label for the last row of all arrays passed as `inputs`.
|
||||
assets : np.array[int64, ndim=1]
|
||||
Column labels for `out` and`inputs`.
|
||||
out : np.array[bool, ndim=1]
|
||||
Output array of the same shape as `assets`. `compute` should write
|
||||
its desired return values into `out`.
|
||||
*inputs : tuple of np.array
|
||||
Raw data arrays corresponding to the values of `self.inputs`.
|
||||
|
||||
See the documentation for
|
||||
:class:`~zipline.pipeline.factors.factor.CustomFactor` for more details on
|
||||
implementing a custom ``compute`` method.
|
||||
|
||||
See Also
|
||||
--------
|
||||
zipline.pipeline.factors.factor.CustomFactor
|
||||
"""
|
||||
ctx = nullctx()
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Filter that produces the most most recently-known value of a boolean-valued
|
||||
Column.
|
||||
"""
|
||||
from zipline.utils.numpy_utils import bool_dtype
|
||||
|
||||
from .filter import CustomFilter
|
||||
from ..mixins import SingleInputMixin
|
||||
|
||||
|
||||
class Latest(SingleInputMixin, CustomFilter):
|
||||
"""
|
||||
Filter producing the most recently-known value of `inputs[0]` on each day.
|
||||
"""
|
||||
window_length = 1
|
||||
|
||||
def compute(self, today, assets, out, data):
|
||||
out[:] = data[-1]
|
||||
|
||||
def _validate(self):
|
||||
if self.inputs[0].dtype != bool_dtype:
|
||||
raise TypeError(
|
||||
"{name} expected an input of dtype bool, "
|
||||
"but got {not_bool} instead.".format(
|
||||
name=type(self).__name__,
|
||||
not_bool=self.inputs[0].dtype,
|
||||
)
|
||||
)
|
||||
super(Latest, self)._validate()
|
||||
Reference in New Issue
Block a user