mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 06:13:40 +08:00
add groupby to rank, top, and bottom
This commit is contained in:
committed by
Scott Sanderson
parent
161922897e
commit
9e3404646e
@@ -336,6 +336,207 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
for method in results:
|
||||
check_arrays(expected[method], results[method])
|
||||
|
||||
def test_grouped_rank_ascending(self, factor_dtype=float64_dtype):
    """
    Verify ``Factor.rank(groupby=...)`` in ascending order.

    Every supported ranking method is run against both an integer
    classifier and a string (categorical) classifier holding the same
    labels, and the output is compared with precomputed per-group ranks.
    """
    factor = F(dtype=factor_dtype)
    int_classifier = C()
    str_classifier = C(dtype=categorical_dtype, missing_value=None)

    # Equivalent to: arange(25).reshape(5, 5).transpose() % 4
    factor_values = array(
        [[0, 1, 2, 3, 0],
         [1, 2, 3, 0, 1],
         [2, 3, 0, 1, 2],
         [3, 0, 1, 2, 3],
         [0, 1, 2, 3, 0]],
        dtype=factor_dtype,
    )

    # Equivalent to: arange(25).reshape(5, 5).transpose() % 2
    int_labels = array(
        [[0, 1, 0, 1, 0],
         [1, 0, 1, 0, 1],
         [0, 1, 0, 1, 0],
         [1, 0, 1, 0, 1],
         [0, 1, 0, 1, 0]],
        dtype=int64_dtype,
    )
    # Same grouping as ``int_labels``, expressed as string labels.
    str_labels = LabelArray(
        int_labels.astype(str).astype(object),
        missing_value=None,
    )

    expected_by_method = {
        'ordinal': array([
            [1., 1., 3., 2., 2.],
            [1., 2., 3., 1., 2.],
            [2., 2., 1., 1., 3.],
            [2., 1., 1., 2., 3.],
            [1., 1., 3., 2., 2.],
        ]),
        'average': array([
            [1.5, 1., 3., 2., 1.5],
            [1.5, 2., 3., 1., 1.5],
            [2.5, 2., 1., 1., 2.5],
            [2.5, 1., 1., 2., 2.5],
            [1.5, 1., 3., 2., 1.5],
        ]),
        'min': array([
            [1., 1., 3., 2., 1.],
            [1., 2., 3., 1., 1.],
            [2., 2., 1., 1., 2.],
            [2., 1., 1., 2., 2.],
            [1., 1., 3., 2., 1.],
        ]),
        'max': array([
            [2., 1., 3., 2., 2.],
            [2., 2., 3., 1., 2.],
            [3., 2., 1., 1., 3.],
            [3., 1., 1., 2., 3.],
            [2., 1., 3., 2., 2.],
        ]),
        'dense': array([
            [1., 1., 2., 2., 1.],
            [1., 2., 2., 1., 1.],
            [2., 2., 1., 1., 2.],
            [2., 1., 1., 2., 2.],
            [1., 1., 2., 2., 1.],
        ]),
    }

    def assert_ranks(terms):
        # ``terms`` maps a rank-method name to the term under test; each
        # computed result is compared with the expectation for that name.
        workspace = {
            factor: factor_values,
            int_classifier: int_labels,
            str_classifier: str_labels,
        }
        results = self.run_graph(
            TermGraph(terms),
            initial_workspace=workspace,
            mask=self.build_mask(ones((5, 5))),
        )
        for name in terms:
            check_arrays(results[name], expected_by_method[name])

    # Leaving ``ascending`` unspecified must behave like ``ascending=True``,
    # so both spellings are run against the same expectations.
    direction_kwargs = ({}, {'ascending': True})

    for extra in direction_kwargs:
        for classifier in (int_classifier, str_classifier):
            assert_ranks({
                name: factor.rank(method=name, groupby=classifier, **extra)
                for name in expected_by_method
            })

    # Leaving ``method`` unspecified must fall back to ordinal ranking.
    for extra in direction_kwargs:
        for classifier in (int_classifier, str_classifier):
            assert_ranks(
                {'ordinal': factor.rank(groupby=classifier, **extra)}
            )
|
||||
|
||||
def test_grouped_rank_descending(self, factor_dtype=float64_dtype):
    """
    Verify ``Factor.rank(groupby=..., ascending=False)``.

    Every supported ranking method is run against both an integer
    classifier and a string (categorical) classifier holding the same
    labels, and the output is compared with precomputed per-group ranks.
    """
    factor = F(dtype=factor_dtype)
    int_classifier = C()
    str_classifier = C(dtype=categorical_dtype, missing_value=None)

    # Equivalent to: arange(25).reshape(5, 5).transpose() % 4
    factor_values = array(
        [[0, 1, 2, 3, 0],
         [1, 2, 3, 0, 1],
         [2, 3, 0, 1, 2],
         [3, 0, 1, 2, 3],
         [0, 1, 2, 3, 0]],
        dtype=factor_dtype,
    )

    # Equivalent to: arange(25).reshape(5, 5).transpose() % 2
    int_labels = array(
        [[0, 1, 0, 1, 0],
         [1, 0, 1, 0, 1],
         [0, 1, 0, 1, 0],
         [1, 0, 1, 0, 1],
         [0, 1, 0, 1, 0]],
        dtype=int64_dtype,
    )
    # Same grouping as ``int_labels``, expressed as string labels.
    str_labels = LabelArray(
        int_labels.astype(str).astype(object),
        missing_value=None,
    )

    expected_by_method = {
        'ordinal': array([
            [2., 2., 1., 1., 3.],
            [2., 1., 1., 2., 3.],
            [1., 1., 3., 2., 2.],
            [1., 2., 3., 1., 2.],
            [2., 2., 1., 1., 3.],
        ]),
        'average': array([
            [2.5, 2., 1., 1., 2.5],
            [2.5, 1., 1., 2., 2.5],
            [1.5, 1., 3., 2., 1.5],
            [1.5, 2., 3., 1., 1.5],
            [2.5, 2., 1., 1., 2.5],
        ]),
        'min': array([
            [2., 2., 1., 1., 2.],
            [2., 1., 1., 2., 2.],
            [1., 1., 3., 2., 1.],
            [1., 2., 3., 1., 1.],
            [2., 2., 1., 1., 2.],
        ]),
        'max': array([
            [3., 2., 1., 1., 3.],
            [3., 1., 1., 2., 3.],
            [2., 1., 3., 2., 2.],
            [2., 2., 3., 1., 2.],
            [3., 2., 1., 1., 3.],
        ]),
        'dense': array([
            [2., 2., 1., 1., 2.],
            [2., 1., 1., 2., 2.],
            [1., 1., 2., 2., 1.],
            [1., 2., 2., 1., 1.],
            [2., 2., 1., 1., 2.],
        ]),
    }

    def assert_ranks(terms):
        # ``terms`` maps a rank-method name to the term under test; each
        # computed result is compared with the expectation for that name.
        workspace = {
            factor: factor_values,
            int_classifier: int_labels,
            str_classifier: str_labels,
        }
        results = self.run_graph(
            TermGraph(terms),
            initial_workspace=workspace,
            mask=self.build_mask(ones((5, 5))),
        )
        for name in terms:
            check_arrays(results[name], expected_by_method[name])

    classifiers = (int_classifier, str_classifier)

    # Every explicit method, grouped by each classifier flavour.
    for classifier in classifiers:
        assert_ranks({
            name: factor.rank(
                method=name,
                groupby=classifier,
                ascending=False,
            )
            for name in expected_by_method
        })

    # Leaving ``method`` unspecified must fall back to ordinal ranking.
    for classifier in classifiers:
        assert_ranks(
            {'ordinal': factor.rank(groupby=classifier, ascending=False)}
        )
|
||||
|
||||
# TODO: implement test_grouped_rank_after_mask (stubbed out below).
|
||||
# @for_each_factor_dtype
|
||||
# def test_grouped_rank_after_mask(self, name, factor_dtype):
|
||||
# pass
|
||||
|
||||
@parameterized.expand([
|
||||
# Test cases computed by doing:
|
||||
# from numpy.random import seed, randn
|
||||
|
||||
@@ -6,6 +6,7 @@ from operator import attrgetter
|
||||
from numbers import Number
|
||||
|
||||
from numpy import inf, where
|
||||
from scipy.stats import rankdata
|
||||
|
||||
from zipline.errors import UnknownRankMethod
|
||||
from zipline.lib.normalize import naive_grouped_rowwise_apply
|
||||
@@ -581,7 +582,11 @@ class Factor(RestrictedDTypeMixin, ComputableTerm):
|
||||
window_safe=True,
|
||||
)
|
||||
|
||||
def rank(self, method='ordinal', ascending=True, mask=NotSpecified):
|
||||
def rank(self,
|
||||
method='ordinal',
|
||||
ascending=True,
|
||||
mask=NotSpecified,
|
||||
groupby=NotSpecified):
|
||||
"""
|
||||
Construct a new Factor representing the sorted rank of each column
|
||||
within each row.
|
||||
@@ -599,6 +604,8 @@ class Factor(RestrictedDTypeMixin, ComputableTerm):
|
||||
A Filter representing assets to consider when computing ranks.
|
||||
If mask is supplied, ranks are computed ignoring any asset/date
|
||||
pairs for which `mask` produces a value of False.
|
||||
groupby : zipline.pipeline.Classifier, optional
|
||||
A classifier defining partitions over which to perform ranking.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -620,7 +627,21 @@ class Factor(RestrictedDTypeMixin, ComputableTerm):
|
||||
:func:`scipy.stats.rankdata`
|
||||
:class:`zipline.pipeline.factors.factor.Rank`
|
||||
"""
|
||||
return Rank(self, method=method, ascending=ascending, mask=mask)
|
||||
|
||||
if groupby is NotSpecified:
|
||||
return Rank(self, method=method, ascending=ascending, mask=mask)
|
||||
|
||||
else:
|
||||
def rank(row):
|
||||
return rankdata(row if ascending else -row, method=method)
|
||||
|
||||
return GroupedRowTransform(
|
||||
transform=rank,
|
||||
factor=self,
|
||||
mask=mask,
|
||||
groupby=groupby,
|
||||
window_safe=True,
|
||||
)
|
||||
|
||||
@expect_types(
|
||||
target=Term, correlation_length=int, mask=(Filter, NotSpecifiedType),
|
||||
@@ -913,7 +934,7 @@ class Factor(RestrictedDTypeMixin, ComputableTerm):
|
||||
"""
|
||||
return self.quantiles(bins=10, mask=mask)
|
||||
|
||||
def top(self, N, mask=NotSpecified):
|
||||
def top(self, N, mask=NotSpecified, groupby=NotSpecified):
|
||||
"""
|
||||
Construct a Filter matching the top N asset values of self each day.
|
||||
|
||||
@@ -925,14 +946,16 @@ class Factor(RestrictedDTypeMixin, ComputableTerm):
|
||||
A Filter representing assets to consider when computing ranks.
|
||||
If mask is supplied, top values are computed ignoring any
|
||||
asset/date pairs for which `mask` produces a value of False.
|
||||
groupby : zipline.pipeline.Classifier, optional
|
||||
A classifier defining partitions over which to perform ranking.
|
||||
|
||||
Returns
|
||||
-------
|
||||
filter : zipline.pipeline.filters.Filter
|
||||
"""
|
||||
return self.rank(ascending=False, mask=mask) <= N
|
||||
return self.rank(ascending=False, mask=mask, groupby=groupby) <= N
|
||||
|
||||
def bottom(self, N, mask=NotSpecified):
|
||||
def bottom(self, N, mask=NotSpecified, groupby=NotSpecified):
|
||||
"""
|
||||
Construct a Filter matching the bottom N asset values of self each day.
|
||||
|
||||
@@ -944,12 +967,14 @@ class Factor(RestrictedDTypeMixin, ComputableTerm):
|
||||
A Filter representing assets to consider when computing ranks.
|
||||
If mask is supplied, bottom values are computed ignoring any
|
||||
asset/date pairs for which `mask` produces a value of False.
|
||||
groupby : zipline.pipeline.Classifier, optional
|
||||
A classifier defining partitions over which to perform ranking.
|
||||
|
||||
Returns
|
||||
-------
|
||||
filter : zipline.pipeline.Filter
|
||||
"""
|
||||
return self.rank(ascending=True, mask=mask) <= N
|
||||
return self.rank(ascending=True, mask=mask, groupby=groupby) <= N
|
||||
|
||||
def percentile_between(self,
|
||||
min_percentile,
|
||||
@@ -1075,7 +1100,7 @@ class GroupedRowTransform(Factor):
|
||||
Factor.
|
||||
|
||||
This is most often useful for normalization operators like ``zscore`` or
|
||||
``demean``.
|
||||
``demean``, or for performing ranking using ``rank``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -1093,12 +1118,13 @@ class GroupedRowTransform(Factor):
|
||||
-----
|
||||
Users should rarely construct instances of this factor directly. Instead,
|
||||
they should construct instances via factor normalization methods like
|
||||
``zscore`` and ``demean``.
|
||||
``zscore`` and ``demean``, or by using ``rank`` with ``groupby``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
zipline.pipeline.factors.Factor.zscore
|
||||
zipline.pipeline.factors.Factor.demean
|
||||
zipline.pipeline.factors.Factor.rank
|
||||
"""
|
||||
window_length = 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user