mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 23:23:34 +08:00
ENH: Add SpecificAssets filter.
Adds a filter that matches a set of assets. Mainly useful for testing and debugging.
This commit is contained in:
@@ -26,11 +26,13 @@ from numpy import (
|
||||
from numpy.random import randn, seed as random_seed
|
||||
|
||||
from zipline.errors import BadPercentileBounds
|
||||
from zipline.pipeline import Filter, Factor
|
||||
from zipline.pipeline import Filter, Factor, Pipeline
|
||||
from zipline.pipeline.classifiers import Classifier
|
||||
from zipline.pipeline.factors import CustomFactor
|
||||
from zipline.pipeline.filters import All, Any, AtLeastN
|
||||
from zipline.testing import parameter_space, permute_rows
|
||||
from zipline.pipeline.filters import All, Any, AtLeastN, SpecificAssets
|
||||
from zipline.testing import parameter_space, permute_rows, ZiplineTestCase
|
||||
from zipline.testing.fixtures import WithSeededRandomPipelineEngine
|
||||
from zipline.testing.predicates import assert_equal
|
||||
from zipline.utils.numpy_utils import float64_dtype, int64_dtype
|
||||
from .base import BasePipelineTestCase, with_default_shape
|
||||
|
||||
@@ -820,3 +822,40 @@ class FilterTestCase(BasePipelineTestCase):
|
||||
},
|
||||
mask=self.build_mask(permute(rot90(self.eye_mask(shape=shape)))),
|
||||
)
|
||||
|
||||
|
||||
class SpecificAssetsTestCase(WithSeededRandomPipelineEngine,
|
||||
ZiplineTestCase):
|
||||
|
||||
ASSET_FINDER_EQUITY_SIDS = tuple(range(10))
|
||||
|
||||
def test_specific_assets(self):
|
||||
assets = self.asset_finder.retrieve_all(self.ASSET_FINDER_EQUITY_SIDS)
|
||||
|
||||
class SidFactor(CustomFactor):
|
||||
"""A factor that just returns each asset's sid."""
|
||||
inputs = ()
|
||||
window_length = 1
|
||||
|
||||
def compute(self, today, sids, out):
|
||||
out[:] = sids
|
||||
|
||||
pipe = Pipeline(
|
||||
columns={
|
||||
'sid': SidFactor(),
|
||||
'evens': SpecificAssets(assets[::2]),
|
||||
'odds': SpecificAssets(assets[1::2]),
|
||||
'first_five': SpecificAssets(assets[:5]),
|
||||
'last_three': SpecificAssets(assets[-3:]),
|
||||
},
|
||||
)
|
||||
|
||||
start, end = self.trading_days[[-10, -1]]
|
||||
results = self.run_pipeline(pipe, start, end).unstack()
|
||||
|
||||
sids = results.sid.astype(int64_dtype)
|
||||
|
||||
assert_equal(results.evens, ~(sids % 2).astype(bool))
|
||||
assert_equal(results.odds, (sids % 2).astype(bool))
|
||||
assert_equal(results.first_five, sids < 5)
|
||||
assert_equal(results.last_three, sids >= 7)
|
||||
|
||||
@@ -8,6 +8,7 @@ from .filter import (
|
||||
NumExprFilter,
|
||||
PercentileFilter,
|
||||
SingleAsset,
|
||||
SpecificAssets,
|
||||
)
|
||||
from .smoothing import All, Any, AtLeastN
|
||||
|
||||
@@ -24,4 +25,5 @@ __all__ = [
|
||||
'NumExprFilter',
|
||||
'PercentileFilter',
|
||||
'SingleAsset',
|
||||
'SpecificAssets',
|
||||
]
|
||||
|
||||
@@ -4,11 +4,13 @@ filter.py
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
|
||||
import numpy as np
|
||||
from numpy import (
|
||||
float64,
|
||||
nan,
|
||||
nanpercentile,
|
||||
)
|
||||
import pandas as pd
|
||||
|
||||
from zipline.errors import (
|
||||
BadPercentileBounds,
|
||||
@@ -32,7 +34,7 @@ from zipline.pipeline.mixins import (
|
||||
SingleInputMixin,
|
||||
)
|
||||
from zipline.pipeline.term import ComputableTerm, Term
|
||||
from zipline.utils.input_validation import expect_types
|
||||
from zipline.utils.input_validation import coerce_types, expect_types
|
||||
from zipline.utils.memoize import classlazyval
|
||||
from zipline.utils.numpy_utils import bool_dtype, repeat_first_axis
|
||||
|
||||
@@ -494,3 +496,22 @@ class SingleAsset(Filter):
|
||||
asset=self._asset, start_date=dates[0], end_date=dates[-1],
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
class SpecificAssets(Filter):
|
||||
"""
|
||||
A Filter that computes True for a specific set of predetermined assets.
|
||||
"""
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
params = ('sids',)
|
||||
|
||||
@expect_types(assets=(list, tuple, np.ndarray))
|
||||
@coerce_types(assets=((list, np.ndarray, pd.Series), list))
|
||||
def __new__(cls, assets):
|
||||
sids = frozenset(asset.sid for asset in assets)
|
||||
return super(SpecificAssets, cls).__new__(cls, sids=sids)
|
||||
|
||||
def _compute(self, arrays, dates, sids, mask):
|
||||
my_columns = sids.isin(self.params['sids'])
|
||||
return repeat_first_axis(my_columns, len(mask)) & mask
|
||||
|
||||
Reference in New Issue
Block a user