mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-06 01:05:37 +08:00
ENH: Adds StaticSids pipeline filter (#1717)
Useful for avoiding the need to create Asset objects when sids are easier to use. This is based off the existing implementation of StaticAssets, and StaticAssets is now implemented as a wrapper around StaticSids.
This commit is contained in:
@@ -30,7 +30,13 @@ from zipline.errors import BadPercentileBounds
|
||||
from zipline.pipeline import Filter, Factor, Pipeline
|
||||
from zipline.pipeline.classifiers import Classifier
|
||||
from zipline.pipeline.factors import CustomFactor
|
||||
from zipline.pipeline.filters import All, Any, AtLeastN, StaticAssets
|
||||
from zipline.pipeline.filters import (
|
||||
All,
|
||||
Any,
|
||||
AtLeastN,
|
||||
StaticAssets,
|
||||
StaticSids,
|
||||
)
|
||||
from zipline.testing import parameter_space, permute_rows, ZiplineTestCase
|
||||
from zipline.testing.fixtures import WithSeededRandomPipelineEngine
|
||||
from zipline.testing.predicates import assert_equal
|
||||
@@ -825,29 +831,28 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
)
|
||||
|
||||
|
||||
class SidFactor(CustomFactor):
|
||||
"""A factor that just returns each asset's sid."""
|
||||
inputs = ()
|
||||
window_length = 1
|
||||
|
||||
def compute(self, today, sids, out):
|
||||
out[:] = sids
|
||||
|
||||
|
||||
class SpecificAssetsTestCase(WithSeededRandomPipelineEngine,
|
||||
ZiplineTestCase):
|
||||
|
||||
ASSET_FINDER_EQUITY_SIDS = tuple(range(10))
|
||||
|
||||
def test_specific_assets(self):
|
||||
assets = self.asset_finder.retrieve_all(self.ASSET_FINDER_EQUITY_SIDS)
|
||||
|
||||
class SidFactor(CustomFactor):
|
||||
"""A factor that just returns each asset's sid."""
|
||||
inputs = ()
|
||||
window_length = 1
|
||||
|
||||
def compute(self, today, sids, out):
|
||||
out[:] = sids
|
||||
|
||||
def _check_filters(self, evens, odds, first_five, last_three):
|
||||
pipe = Pipeline(
|
||||
columns={
|
||||
'sid': SidFactor(),
|
||||
'evens': StaticAssets(assets[::2]),
|
||||
'odds': StaticAssets(assets[1::2]),
|
||||
'first_five': StaticAssets(assets[:5]),
|
||||
'last_three': StaticAssets(assets[-3:]),
|
||||
'evens': evens,
|
||||
'odds': odds,
|
||||
'first_five': first_five,
|
||||
'last_three': last_three,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -861,6 +866,26 @@ class SpecificAssetsTestCase(WithSeededRandomPipelineEngine,
|
||||
assert_equal(results.first_five, sids < 5)
|
||||
assert_equal(results.last_three, sids >= 7)
|
||||
|
||||
def test_specific_assets(self):
|
||||
assets = self.asset_finder.retrieve_all(self.ASSET_FINDER_EQUITY_SIDS)
|
||||
|
||||
self._check_filters(
|
||||
evens=StaticAssets(assets[::2]),
|
||||
odds=StaticAssets(assets[1::2]),
|
||||
first_five=StaticAssets(assets[:5]),
|
||||
last_three=StaticAssets(assets[-3:]),
|
||||
)
|
||||
|
||||
def test_specific_sids(self):
|
||||
sids = self.ASSET_FINDER_EQUITY_SIDS
|
||||
|
||||
self._check_filters(
|
||||
evens=StaticSids(sids[::2]),
|
||||
odds=StaticSids(sids[1::2]),
|
||||
first_five=StaticSids(sids[:5]),
|
||||
last_three=StaticSids(sids[-3:]),
|
||||
)
|
||||
|
||||
|
||||
class TestPostProcessAndToWorkSpaceValue(ZiplineTestCase):
|
||||
def test_reversability(self):
|
||||
|
||||
@@ -9,6 +9,7 @@ from .filter import (
|
||||
PercentileFilter,
|
||||
SingleAsset,
|
||||
StaticAssets,
|
||||
StaticSids,
|
||||
)
|
||||
from .smoothing import All, Any, AtLeastN
|
||||
|
||||
@@ -26,4 +27,5 @@ __all__ = [
|
||||
'PercentileFilter',
|
||||
'SingleAsset',
|
||||
'StaticAssets',
|
||||
'StaticSids',
|
||||
]
|
||||
|
||||
@@ -502,7 +502,33 @@ class SingleAsset(Filter):
|
||||
return out
|
||||
|
||||
|
||||
class StaticAssets(Filter):
|
||||
class StaticSids(Filter):
|
||||
"""
|
||||
A Filter that computes True for a specific set of predetermined sids.
|
||||
|
||||
``StaticSids`` is mostly useful for debugging or for interactively
|
||||
computing pipeline terms for a fixed set of sids that are known ahead of
|
||||
time.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sids : iterable[int]
|
||||
An iterable of sids for which to filter.
|
||||
"""
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
params = ('sids',)
|
||||
|
||||
def __new__(cls, sids):
|
||||
sids = frozenset(sids)
|
||||
return super(StaticSids, cls).__new__(cls, sids=sids)
|
||||
|
||||
def _compute(self, arrays, dates, sids, mask):
|
||||
my_columns = sids.isin(self.params['sids'])
|
||||
return repeat_first_axis(my_columns, len(mask)) & mask
|
||||
|
||||
|
||||
class StaticAssets(StaticSids):
|
||||
"""
|
||||
A Filter that computes True for a specific set of predetermined assets.
|
||||
|
||||
@@ -515,14 +541,6 @@ class StaticAssets(Filter):
|
||||
assets : iterable[Asset]
|
||||
An iterable of assets for which to filter.
|
||||
"""
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
params = ('sids',)
|
||||
|
||||
def __new__(cls, assets):
|
||||
sids = frozenset(asset.sid for asset in assets)
|
||||
return super(StaticAssets, cls).__new__(cls, sids=sids)
|
||||
|
||||
def _compute(self, arrays, dates, sids, mask):
|
||||
my_columns = sids.isin(self.params['sids'])
|
||||
return repeat_first_axis(my_columns, len(mask)) & mask
|
||||
return super(StaticAssets, cls).__new__(cls, sids)
|
||||
|
||||
Reference in New Issue
Block a user