mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-06 05:14:38 +08:00
@@ -32,13 +32,20 @@ Enhancements
|
||||
factors use the new ``CashBuybackAuthorizations`` and
|
||||
``ShareBuybackAuthorizations`` datasets, respectively. (:issue:`1022`).
|
||||
|
||||
* Implemented :class:`zipline.pipeline.Classifier`, a new core pipeline API
|
||||
term representing grouping keys. Classifiers are primarily used by passing
|
||||
them as the ``groupby`` parameter to factor normalization methods.
|
||||
|
||||
* Added factor normalization methods:
|
||||
:meth:`zipline.pipeline.Factor.demean` and
|
||||
:meth:`zipline.pipeline.Factor.zscore`. (:issue:`1046`)
|
||||
|
||||
* Implemented :class:`zipline.pipeline.Classifier`, a new core pipeline API
|
||||
term representing grouping keys. Classifiers are primarily used by passing
|
||||
them as the ``groupby`` parameter to factor normalization methods.
|
||||
* Added :meth:`zipline.pipeline.Factor.quantiles`, a method for computing a
|
||||
Classifier from a Factor by partitioning into equally-sized buckets. Also
|
||||
added helpers for common quantile sizes
|
||||
(:meth:`zipline.pipeline.Factor.quartiles`,
|
||||
:meth:`zipline.pipeline.Factor.quintiles`, and
|
||||
:meth:`zipline.pipeline.Factor.deciles`) (:issue:`1075`).
|
||||
|
||||
Experimental Features
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
+13
-1
@@ -13,10 +13,11 @@ from pandas import date_range, Int64Index, DataFrame
|
||||
from pandas.util.testing import assert_series_equal
|
||||
from six import iteritems
|
||||
|
||||
from zipline.pipeline import Pipeline
|
||||
from zipline.pipeline import Pipeline, TermGraph
|
||||
from zipline.pipeline.engine import SimplePipelineEngine
|
||||
from zipline.pipeline.term import AssetExists
|
||||
from zipline.testing import (
|
||||
check_arrays,
|
||||
ExplodingObject,
|
||||
gen_calendars,
|
||||
make_simple_equity_info,
|
||||
@@ -24,6 +25,7 @@ from zipline.testing import (
|
||||
tmp_asset_finder,
|
||||
)
|
||||
|
||||
from zipline.utils.functional import dzip_exact
|
||||
from zipline.utils.numpy_utils import (
|
||||
NaTD,
|
||||
make_datetime64D
|
||||
@@ -125,6 +127,16 @@ class BasePipelineTestCase(TestCase):
|
||||
initial_workspace,
|
||||
)
|
||||
|
||||
def check_terms(self, terms, expected, initial_workspace, mask):
    """
    Build a TermGraph from ``terms``, execute it via ``self.run_graph``
    with ``initial_workspace`` and ``mask``, and check each computed array
    against the matching entry of ``expected``.

    Parameters
    ----------
    terms : dict
        Mapping from output name to the term to compute.
    expected : dict
        Mapping from output name to the array that term should produce.
    initial_workspace : dict
        Pre-computed inputs handed to ``run_graph``.
    mask
        Mask handed to ``run_graph``.
    """
    computed = self.run_graph(TermGraph(terms), initial_workspace, mask)
    # Results and expectations are paired up by key.
    # NOTE(review): dzip_exact presumably rejects mismatched key sets --
    # confirm against zipline.utils.functional.
    paired = dzip_exact(computed, expected)
    for name in paired:
        actual, wanted = paired[name]
        check_arrays(actual, wanted)
|
||||
|
||||
def build_mask(self, array):
|
||||
"""
|
||||
Helper for constructing an AssetExists mask from a boolean-coercible
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
import numpy as np
|
||||
|
||||
from zipline.pipeline import Classifier
|
||||
from zipline.testing import parameter_space
|
||||
from zipline.utils.numpy_utils import int64_dtype
|
||||
|
||||
from .base import BasePipelineTestCase
|
||||
|
||||
|
||||
class ClassifierTestCase(BasePipelineTestCase):
|
||||
|
||||
@parameter_space(mv=[-1, 0, 1, 999])
|
||||
def test_isnull(self, mv):
|
||||
|
||||
class C(Classifier):
|
||||
dtype = int64_dtype
|
||||
missing_value = mv
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
c = C()
|
||||
|
||||
# There's no significance to the values here other than that they
|
||||
# contain a mix of missing and non-missing values.
|
||||
data = np.array([[-1, 1, 0, 2],
|
||||
[3, 0, 1, 0],
|
||||
[-5, 0, -1, 0],
|
||||
[-3, 1, 2, 2]], dtype=int64_dtype)
|
||||
|
||||
self.check_terms(
|
||||
terms={
|
||||
'isnull': c.isnull(),
|
||||
'notnull': c.notnull()
|
||||
},
|
||||
expected={
|
||||
'isnull': data == mv,
|
||||
'notnull': data != mv,
|
||||
},
|
||||
initial_workspace={c: data},
|
||||
mask=self.build_mask(self.ones_mask(shape=data.shape)),
|
||||
)
|
||||
|
||||
@parameter_space(compval=[0, 1, 999])
|
||||
def test_eq(self, compval):
|
||||
|
||||
class C(Classifier):
|
||||
dtype = int64_dtype
|
||||
missing_value = -1
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
c = C()
|
||||
|
||||
# There's no significance to the values here other than that they
|
||||
# contain a mix of the comparison value and other values.
|
||||
data = np.array([[-1, 1, 0, 2],
|
||||
[3, 0, 1, 0],
|
||||
[-5, 0, -1, 0],
|
||||
[-3, 1, 2, 2]], dtype=int64_dtype)
|
||||
|
||||
self.check_terms(
|
||||
terms={
|
||||
'eq': c.eq(compval),
|
||||
},
|
||||
expected={
|
||||
'eq': (data == compval),
|
||||
},
|
||||
initial_workspace={c: data},
|
||||
mask=self.build_mask(self.ones_mask(shape=data.shape)),
|
||||
)
|
||||
|
||||
@parameter_space(missing=[-1, 0, 1])
|
||||
def test_disallow_comparison_to_missing_value(self, missing):
|
||||
class C(Classifier):
|
||||
dtype = int64_dtype
|
||||
missing_value = missing
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
with self.assertRaises(ValueError) as e:
|
||||
C().eq(missing)
|
||||
errmsg = str(e.exception)
|
||||
self.assertEqual(
|
||||
errmsg,
|
||||
"Comparison against self.missing_value ({v}) in C.eq().\n"
|
||||
"Missing values have NaN semantics, so the requested comparison"
|
||||
" would always produce False.\n"
|
||||
"Use the isnull() method to check for missing values.".format(
|
||||
v=missing,
|
||||
),
|
||||
)
|
||||
|
||||
@parameter_space(compval=[0, 1, 999], missing=[-1, 0, 999])
|
||||
def test_not_equal(self, compval, missing):
|
||||
|
||||
class C(Classifier):
|
||||
dtype = int64_dtype
|
||||
missing_value = missing
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
c = C()
|
||||
|
||||
# There's no significance to the values here other than that they
|
||||
# contain a mix of the comparison value and other values.
|
||||
data = np.array([[-1, 1, 0, 2],
|
||||
[3, 0, 1, 0],
|
||||
[-5, 0, -1, 0],
|
||||
[-3, 1, 2, 2]], dtype=int64_dtype)
|
||||
|
||||
self.check_terms(
|
||||
terms={
|
||||
'ne': c != compval,
|
||||
},
|
||||
expected={
|
||||
'ne': (data != compval) & (data != C.missing_value),
|
||||
},
|
||||
initial_workspace={c: data},
|
||||
mask=self.build_mask(self.ones_mask(shape=data.shape)),
|
||||
)
|
||||
+328
-20
@@ -1,9 +1,12 @@
|
||||
"""
|
||||
Tests for Factor terms.
|
||||
"""
|
||||
from functools import partial
|
||||
from itertools import product
|
||||
from nose_parameterized import parameterized
|
||||
from unittest import TestCase
|
||||
|
||||
from toolz import compose
|
||||
from numpy import (
|
||||
apply_along_axis,
|
||||
arange,
|
||||
@@ -11,10 +14,10 @@ from numpy import (
|
||||
datetime64,
|
||||
empty,
|
||||
eye,
|
||||
log1p,
|
||||
nan,
|
||||
nanmean,
|
||||
nanstd,
|
||||
ones,
|
||||
rot90,
|
||||
where,
|
||||
)
|
||||
from numpy.random import randn, seed
|
||||
@@ -31,13 +34,16 @@ from zipline.testing import (
|
||||
check_allclose,
|
||||
check_arrays,
|
||||
parameter_space,
|
||||
permute_rows,
|
||||
)
|
||||
from zipline.utils.functional import dzip_exact
|
||||
from zipline.utils.numpy_utils import (
|
||||
datetime64ns_dtype,
|
||||
float64_dtype,
|
||||
int64_dtype,
|
||||
NaTns,
|
||||
)
|
||||
from zipline.utils.math_utils import nanmean, nanstd
|
||||
|
||||
from .base import BasePipelineTestCase
|
||||
|
||||
@@ -48,6 +54,12 @@ class F(Factor):
|
||||
window_length = 0
|
||||
|
||||
|
||||
class OtherF(Factor):
    """
    A second input-less float64 Factor, distinct from ``F``, so that a test
    can bind two different factor arrays in the same workspace.
    """
    window_length = 0
    inputs = ()
    dtype = float64_dtype
|
||||
|
||||
|
||||
class C(Classifier):
|
||||
dtype = int64_dtype
|
||||
missing_value = -1
|
||||
@@ -423,18 +435,103 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
|
||||
check_arrays(float_result, datetime_result)
|
||||
|
||||
def test_normalizations_hand_computed(self):
|
||||
"""
|
||||
Test the hand-computed example in factor.demean.
|
||||
"""
|
||||
f = self.f
|
||||
m = Mask()
|
||||
c = C()
|
||||
|
||||
factor_data = array(
|
||||
[[1.0, 2.0, 3.0, 4.0],
|
||||
[1.5, 2.5, 3.5, 1.0],
|
||||
[2.0, 3.0, 4.0, 1.5],
|
||||
[2.5, 3.5, 1.0, 2.0]],
|
||||
)
|
||||
filter_data = array(
|
||||
[[False, True, True, True],
|
||||
[True, False, True, True],
|
||||
[True, True, False, True],
|
||||
[True, True, True, False]],
|
||||
dtype=bool,
|
||||
)
|
||||
classifier_data = array(
|
||||
[[1, 1, 2, 2],
|
||||
[1, 1, 2, 2],
|
||||
[1, 1, 2, 2],
|
||||
[1, 1, 2, 2]],
|
||||
dtype=int64_dtype,
|
||||
)
|
||||
|
||||
terms = {
|
||||
'vanilla': f.demean(),
|
||||
'masked': f.demean(mask=m),
|
||||
'grouped': f.demean(groupby=c),
|
||||
'grouped_masked': f.demean(mask=m, groupby=c),
|
||||
}
|
||||
expected = {
|
||||
'vanilla': array(
|
||||
[[-1.500, -0.500, 0.500, 1.500],
|
||||
[-0.625, 0.375, 1.375, -1.125],
|
||||
[-0.625, 0.375, 1.375, -1.125],
|
||||
[0.250, 1.250, -1.250, -0.250]],
|
||||
),
|
||||
'masked': array(
|
||||
[[nan, -1.000, 0.000, 1.000],
|
||||
[-0.500, nan, 1.500, -1.000],
|
||||
[-0.166, 0.833, nan, -0.666],
|
||||
[0.166, 1.166, -1.333, nan]],
|
||||
),
|
||||
'grouped': array(
|
||||
[[-0.500, 0.500, -0.500, 0.500],
|
||||
[-0.500, 0.500, 1.250, -1.250],
|
||||
[-0.500, 0.500, 1.250, -1.250],
|
||||
[-0.500, 0.500, -0.500, 0.500]],
|
||||
),
|
||||
'grouped_masked': array(
|
||||
[[nan, 0.000, -0.500, 0.500],
|
||||
[0.000, nan, 1.250, -1.250],
|
||||
[-0.500, 0.500, nan, 0.000],
|
||||
[-0.500, 0.500, 0.000, nan]]
|
||||
)
|
||||
}
|
||||
|
||||
graph = TermGraph(terms)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
initial_workspace={
|
||||
f: factor_data,
|
||||
c: classifier_data,
|
||||
m: filter_data,
|
||||
},
|
||||
mask=self.build_mask(self.ones_mask(shape=factor_data.shape)),
|
||||
)
|
||||
|
||||
for key, (res, exp) in dzip_exact(results, expected).items():
|
||||
check_allclose(
|
||||
res,
|
||||
exp,
|
||||
# The hand-computed values aren't very precise (in particular,
|
||||
# we truncate repeating decimals at 3 places) This is just
|
||||
# asserting that the example isn't misleading by being totally
|
||||
# wrong.
|
||||
atol=0.001,
|
||||
err_msg="Mismatch for %r" % key
|
||||
)
|
||||
|
||||
@parameter_space(
|
||||
seed_value=range(1, 2),
|
||||
normalizer_name_and_func=[
|
||||
('demean', lambda row: row - nanmean(row)),
|
||||
('zscore', lambda row: (row - nanmean(row)) / nanstd(row)),
|
||||
],
|
||||
add_nulls_to_factor=(False, True,)
|
||||
add_nulls_to_factor=(False, True,),
|
||||
)
|
||||
def test_normalizations(self,
|
||||
seed_value,
|
||||
normalizer_name_and_func,
|
||||
add_nulls_to_factor):
|
||||
def test_normalizations_randomized(self,
|
||||
seed_value,
|
||||
normalizer_name_and_func,
|
||||
add_nulls_to_factor):
|
||||
|
||||
name, func = normalizer_name_and_func
|
||||
|
||||
@@ -445,9 +542,9 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
# Falses on main diagonal.
|
||||
eyemask = self.eye_mask(shape=shape)
|
||||
# Falses on other diagonal.
|
||||
eyemask_T = eyemask.T
|
||||
eyemask90 = rot90(eyemask)
|
||||
# Falses on both diagonals.
|
||||
xmask = eyemask & eyemask_T
|
||||
xmask = eyemask & eyemask90
|
||||
|
||||
# Block of random data.
|
||||
factor_data = self.randn_data(seed=seed_value, shape=shape)
|
||||
@@ -456,12 +553,12 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
|
||||
# Cycles of 0, 1, 2, 0, 1, 2, ...
|
||||
classifier_data = (
|
||||
(self.arange_data(shape=shape, dtype=int) + seed_value) % 3
|
||||
(self.arange_data(shape=shape, dtype=int64_dtype) + seed_value) % 3
|
||||
)
|
||||
# With -1s on main diagonal.
|
||||
classifier_data_eyenulls = where(eyemask, classifier_data, -1)
|
||||
# With -1s on opposite diagonal.
|
||||
classifier_data_eyenulls_T = where(eyemask_T, classifier_data, -1)
|
||||
classifier_data_eyenulls90 = where(eyemask90, classifier_data, -1)
|
||||
# With -1s on both diagonals.
|
||||
classifier_data_xnulls = where(xmask, classifier_data, -1)
|
||||
|
||||
@@ -494,8 +591,8 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
# If the classifier has nulls, we should get NaNs in the
|
||||
# corresponding locations in the output.
|
||||
'grouped_with_nulls': where(
|
||||
eyemask_T,
|
||||
grouped_apply(factor_data, classifier_data_eyenulls_T, func),
|
||||
eyemask90,
|
||||
grouped_apply(factor_data, classifier_data_eyenulls90, func),
|
||||
nan,
|
||||
),
|
||||
# Passing a mask with a classifier should behave as though the
|
||||
@@ -520,21 +617,18 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
)
|
||||
}
|
||||
|
||||
graph = TermGraph(terms)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
self.check_terms(
|
||||
terms=terms,
|
||||
expected=expected,
|
||||
initial_workspace={
|
||||
f: factor_data,
|
||||
c: classifier_data,
|
||||
c_with_nulls: classifier_data_eyenulls_T,
|
||||
c_with_nulls: classifier_data_eyenulls90,
|
||||
Mask(): eyemask,
|
||||
},
|
||||
mask=self.build_mask(nomask),
|
||||
)
|
||||
|
||||
for key in expected:
|
||||
check_arrays(expected[key], results[key])
|
||||
|
||||
@parameter_space(method_name=['demean', 'zscore'])
|
||||
def test_cant_normalize_non_float(self, method_name):
|
||||
class DateFactor(Factor):
|
||||
@@ -553,3 +647,217 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
).format(normalizer=method_name)
|
||||
|
||||
self.assertEqual(errmsg, expected)
|
||||
|
||||
@parameter_space(seed=[1, 2, 3])
|
||||
def test_quantiles_unmasked(self, seed):
|
||||
permute = partial(permute_rows, seed)
|
||||
|
||||
shape = (6, 6)
|
||||
|
||||
# Shuffle the input rows to verify that we don't depend on the order.
|
||||
# Take the log to ensure that we don't depend on linear scaling or
|
||||
# integrality of inputs
|
||||
factor_data = permute(log1p(arange(36, dtype=float).reshape(shape)))
|
||||
|
||||
f = self.f
|
||||
|
||||
# Apply the same shuffle we applied to the input rows to our
|
||||
# expectations. Doing it this way makes it obvious that our
|
||||
# expectation corresponds to our input, while still testing against
|
||||
# a range of input orderings.
|
||||
permuted_array = compose(permute, partial(array, dtype=int64_dtype))
|
||||
self.check_terms(
|
||||
terms={
|
||||
'2': f.quantiles(bins=2),
|
||||
'3': f.quantiles(bins=3),
|
||||
'6': f.quantiles(bins=6),
|
||||
},
|
||||
initial_workspace={
|
||||
f: factor_data,
|
||||
},
|
||||
expected={
|
||||
# The values in the input are all increasing, so the first half
|
||||
# of each row should be in the bottom bucket, and the second
|
||||
# half should be in the top bucket.
|
||||
'2': permuted_array([[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1]]),
|
||||
# Similar for three buckets.
|
||||
'3': permuted_array([[0, 0, 1, 1, 2, 2],
|
||||
[0, 0, 1, 1, 2, 2],
|
||||
[0, 0, 1, 1, 2, 2],
|
||||
[0, 0, 1, 1, 2, 2],
|
||||
[0, 0, 1, 1, 2, 2],
|
||||
[0, 0, 1, 1, 2, 2]]),
|
||||
# In the limiting case, we just have every column different.
|
||||
'6': permuted_array([[0, 1, 2, 3, 4, 5],
|
||||
[0, 1, 2, 3, 4, 5],
|
||||
[0, 1, 2, 3, 4, 5],
|
||||
[0, 1, 2, 3, 4, 5],
|
||||
[0, 1, 2, 3, 4, 5],
|
||||
[0, 1, 2, 3, 4, 5]]),
|
||||
},
|
||||
mask=self.build_mask(self.ones_mask(shape=shape)),
|
||||
)
|
||||
|
||||
@parameter_space(seed=[1, 2, 3])
|
||||
def test_quantiles_masked(self, seed):
|
||||
permute = partial(permute_rows, seed)
|
||||
|
||||
# 7 x 7 so that we divide evenly into 2/3/6-tiles after including the
|
||||
# nan value in each row.
|
||||
shape = (7, 7)
|
||||
|
||||
# Shuffle the input rows to verify that we don't depend on the order.
|
||||
# Take the log to ensure that we don't depend on linear scaling or
|
||||
# integrality of inputs
|
||||
factor_data = permute(log1p(arange(49, dtype=float).reshape(shape)))
|
||||
factor_data_w_nans = where(
|
||||
permute(rot90(self.eye_mask(shape=shape))),
|
||||
factor_data,
|
||||
nan,
|
||||
)
|
||||
mask_data = permute(self.eye_mask(shape=shape))
|
||||
|
||||
f = F()
|
||||
f_nans = OtherF()
|
||||
m = Mask()
|
||||
|
||||
# Apply the same shuffle we applied to the input rows to our
|
||||
# expectations. Doing it this way makes it obvious that our
|
||||
# expectation corresponds to our input, while still testing against
|
||||
# a range of input orderings.
|
||||
permuted_array = compose(permute, partial(array, dtype=int64_dtype))
|
||||
|
||||
self.check_terms(
|
||||
terms={
|
||||
'2_masked': f.quantiles(bins=2, mask=m),
|
||||
'3_masked': f.quantiles(bins=3, mask=m),
|
||||
'6_masked': f.quantiles(bins=6, mask=m),
|
||||
'2_nans': f_nans.quantiles(bins=2),
|
||||
'3_nans': f_nans.quantiles(bins=3),
|
||||
'6_nans': f_nans.quantiles(bins=6),
|
||||
},
|
||||
initial_workspace={
|
||||
f: factor_data,
|
||||
f_nans: factor_data_w_nans,
|
||||
m: mask_data,
|
||||
},
|
||||
expected={
|
||||
# Expected results here are the same as in
|
||||
# test_quantiles_unmasked, except with diagonals of -1s
|
||||
# interpolated to match the effects of masking and/or input
|
||||
# nans.
|
||||
'2_masked': permuted_array([[-1, 0, 0, 0, 1, 1, 1],
|
||||
[0, -1, 0, 0, 1, 1, 1],
|
||||
[0, 0, -1, 0, 1, 1, 1],
|
||||
[0, 0, 0, -1, 1, 1, 1],
|
||||
[0, 0, 0, 1, -1, 1, 1],
|
||||
[0, 0, 0, 1, 1, -1, 1],
|
||||
[0, 0, 0, 1, 1, 1, -1]]),
|
||||
'3_masked': permuted_array([[-1, 0, 0, 1, 1, 2, 2],
|
||||
[0, -1, 0, 1, 1, 2, 2],
|
||||
[0, 0, -1, 1, 1, 2, 2],
|
||||
[0, 0, 1, -1, 1, 2, 2],
|
||||
[0, 0, 1, 1, -1, 2, 2],
|
||||
[0, 0, 1, 1, 2, -1, 2],
|
||||
[0, 0, 1, 1, 2, 2, -1]]),
|
||||
'6_masked': permuted_array([[-1, 0, 1, 2, 3, 4, 5],
|
||||
[0, -1, 1, 2, 3, 4, 5],
|
||||
[0, 1, -1, 2, 3, 4, 5],
|
||||
[0, 1, 2, -1, 3, 4, 5],
|
||||
[0, 1, 2, 3, -1, 4, 5],
|
||||
[0, 1, 2, 3, 4, -1, 5],
|
||||
[0, 1, 2, 3, 4, 5, -1]]),
|
||||
'2_nans': permuted_array([[0, 0, 0, 1, 1, 1, -1],
|
||||
[0, 0, 0, 1, 1, -1, 1],
|
||||
[0, 0, 0, 1, -1, 1, 1],
|
||||
[0, 0, 0, -1, 1, 1, 1],
|
||||
[0, 0, -1, 0, 1, 1, 1],
|
||||
[0, -1, 0, 0, 1, 1, 1],
|
||||
[-1, 0, 0, 0, 1, 1, 1]]),
|
||||
'3_nans': permuted_array([[0, 0, 1, 1, 2, 2, -1],
|
||||
[0, 0, 1, 1, 2, -1, 2],
|
||||
[0, 0, 1, 1, -1, 2, 2],
|
||||
[0, 0, 1, -1, 1, 2, 2],
|
||||
[0, 0, -1, 1, 1, 2, 2],
|
||||
[0, -1, 0, 1, 1, 2, 2],
|
||||
[-1, 0, 0, 1, 1, 2, 2]]),
|
||||
'6_nans': permuted_array([[0, 1, 2, 3, 4, 5, -1],
|
||||
[0, 1, 2, 3, 4, -1, 5],
|
||||
[0, 1, 2, 3, -1, 4, 5],
|
||||
[0, 1, 2, -1, 3, 4, 5],
|
||||
[0, 1, -1, 2, 3, 4, 5],
|
||||
[0, -1, 1, 2, 3, 4, 5],
|
||||
[-1, 0, 1, 2, 3, 4, 5]]),
|
||||
},
|
||||
mask=self.build_mask(self.ones_mask(shape=shape)),
|
||||
)
|
||||
|
||||
def test_quantiles_uneven_buckets(self):
|
||||
permute = partial(permute_rows, 5)
|
||||
shape = (5, 5)
|
||||
|
||||
factor_data = permute(log1p(arange(25, dtype=float).reshape(shape)))
|
||||
mask_data = permute(self.eye_mask(shape=shape))
|
||||
|
||||
f = F()
|
||||
m = Mask()
|
||||
|
||||
permuted_array = compose(permute, partial(array, dtype=int64_dtype))
|
||||
self.check_terms(
|
||||
terms={
|
||||
'3_masked': f.quantiles(bins=3, mask=m),
|
||||
'7_masked': f.quantiles(bins=7, mask=m),
|
||||
},
|
||||
initial_workspace={
|
||||
f: factor_data,
|
||||
m: mask_data,
|
||||
},
|
||||
expected={
|
||||
'3_masked': permuted_array([[-1, 0, 0, 1, 2],
|
||||
[0, -1, 0, 1, 2],
|
||||
[0, 0, -1, 1, 2],
|
||||
[0, 0, 1, -1, 2],
|
||||
[0, 0, 1, 2, -1]]),
|
||||
'7_masked': permuted_array([[-1, 0, 2, 4, 6],
|
||||
[0, -1, 2, 4, 6],
|
||||
[0, 2, -1, 4, 6],
|
||||
[0, 2, 4, -1, 6],
|
||||
[0, 2, 4, 6, -1]]),
|
||||
},
|
||||
mask=self.build_mask(self.ones_mask(shape=shape)),
|
||||
)
|
||||
|
||||
def test_quantile_helpers(self):
    """
    Each named helper must return the very same term object as the
    equivalent ``quantiles(bins=...)`` call, with and without a mask, and
    the masked and unmasked terms must be distinct objects.
    """
    factor = self.f
    filt = Mask()

    helpers_and_bins = (
        ('quartiles', 4),
        ('quintiles', 5),
        ('deciles', 10),
    )
    for helper_name, nbins in helpers_and_bins:
        helper = getattr(factor, helper_name)
        self.assertIs(helper(), factor.quantiles(bins=nbins))
        self.assertIs(
            helper(mask=filt),
            factor.quantiles(bins=nbins, mask=filt),
        )
        self.assertIsNot(helper(), helper(mask=filt))
|
||||
|
||||
|
||||
class ShortReprTestCase(TestCase):
    """
    Checks on the strings produced by ``short_repr`` for the normalization
    terms built from a Factor.
    """

    def test_demean(self):
        demeaned = F().demean()
        self.assertEqual(
            demeaned.short_repr(),
            "GroupedRowTransform('demean')",
        )

    def test_zscore(self):
        zscored = F().zscore()
        self.assertEqual(
            zscored.short_repr(),
            "GroupedRowTransform('zscore')",
        )
|
||||
|
||||
@@ -6,12 +6,13 @@ from types import FunctionType
|
||||
from unittest import TestCase
|
||||
|
||||
from nose_parameterized import parameterized
|
||||
from numpy import arange, dtype
|
||||
from numpy import arange, array, dtype
|
||||
import pytz
|
||||
from six import PY3
|
||||
|
||||
from zipline.utils.preprocess import call, preprocess
|
||||
from zipline.utils.input_validation import (
|
||||
expect_dimensions,
|
||||
ensure_timezone,
|
||||
expect_element,
|
||||
expect_dtypes,
|
||||
@@ -367,3 +368,38 @@ class PreprocessTestCase(TestCase):
|
||||
with self.assertRaises(TypeError) as e:
|
||||
f('a')
|
||||
self.assertIs(e.exception, error)
|
||||
|
||||
def test_expect_dimensions(self):
    """
    ``expect_dimensions(x=2)`` lets a 2-D argument through and raises a
    ValueError naming the function, the argument and the offending
    dimensionality for 1-D, 3-D and scalar inputs.
    """

    @expect_dimensions(x=2)
    def foo(x, y):
        return x[0, 0]

    # A correctly shaped input reaches the wrapped function.
    self.assertEqual(foo(arange(1).reshape(1, 1), 10), 0)

    # Each rejected input, paired with the message it should produce.
    rejected = (
        (
            arange(1),
            "{qualname}() expected a 2-D array for argument 'x', but got"
            " a 1-D array instead.",
        ),
        (
            arange(1).reshape(1, 1, 1),
            "{qualname}() expected a 2-D array for argument 'x', but got"
            " a 3-D array instead.",
        ),
        (
            array(0),
            "{qualname}() expected a 2-D array for argument 'x', but got"
            " a scalar instead.",
        ),
    )
    for bad_input, template in rejected:
        with self.assertRaises(ValueError) as e:
            foo(bad_input, 1)
        self.assertEqual(
            str(e.exception),
            template.format(qualname=qualname(foo)),
        )
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Algorithms for computing quantiles on numpy arrays.
|
||||
"""
|
||||
from numpy.lib import apply_along_axis
|
||||
from pandas import qcut
|
||||
|
||||
|
||||
def quantiles(data, nbins_or_partition_bounds):
    """
    Compute rowwise array quantiles on an input.

    Parameters
    ----------
    data : np.ndarray
        Array with at least two dimensions. Quantile labels are computed
        independently along axis 1.
    nbins_or_partition_bounds : int or sequence of float
        Forwarded to :func:`pandas.qcut` as ``q``.

    Returns
    -------
    labels : np.ndarray[float64]
        Array shaped like ``data`` holding the bucket label of every
        value, with NaN at locations where ``data`` is NaN.
    """
    def label_row(row):
        # qcut(labels=False) yields integer codes for a NaN-free row but
        # float codes (containing NaN) for a row with missing values.
        # apply_along_axis allocates its output with the dtype of the
        # first row's result, so a NaN-free first row would force later
        # NaN labels into an integer buffer.  Always hand back floats so
        # the output dtype does not depend on row order.
        return qcut(
            row,
            q=nbins_or_partition_bounds,
            labels=False,
        ).astype(float)

    return apply_along_axis(label_row, 1, data)
|
||||
@@ -1,8 +1,15 @@
|
||||
from .classifier import Classifier, CustomClassifier, Everything, Latest
|
||||
from .classifier import (
|
||||
Classifier,
|
||||
CustomClassifier,
|
||||
Quantiles,
|
||||
Everything,
|
||||
Latest,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'Classifier',
|
||||
'CustomClassifier',
|
||||
'Everything',
|
||||
'Latest',
|
||||
'Quantiles',
|
||||
]
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
"""
|
||||
classifier.py
|
||||
"""
|
||||
from numpy import zeros, where
|
||||
from numbers import Number
|
||||
|
||||
from numpy import where, isnan, nan, zeros
|
||||
|
||||
from zipline.lib.quantiles import quantiles
|
||||
from zipline.pipeline.term import ComputableTerm
|
||||
from zipline.utils.input_validation import expect_types
|
||||
from zipline.utils.numpy_utils import int64_dtype
|
||||
|
||||
from ..filters import NullFilter, NumExprFilter
|
||||
from ..mixins import (
|
||||
CustomTermMixin,
|
||||
LatestMixin,
|
||||
PositiveWindowLengthMixin,
|
||||
RestrictedDTypeMixin
|
||||
RestrictedDTypeMixin,
|
||||
SingleInputMixin,
|
||||
)
|
||||
|
||||
|
||||
@@ -26,6 +32,60 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
|
||||
"""
|
||||
ALLOWED_DTYPES = (int64_dtype,) # Used by RestrictedDTypeMixin
|
||||
|
||||
def isnull(self):
    """
    Build a Filter that is True wherever this classifier's output is
    missing.
    """
    missing_filter = NullFilter(self)
    return missing_filter
|
||||
|
||||
def notnull(self):
    """
    Build a Filter that is True wherever this classifier has data, i.e.
    the inverse of the filter returned by ``isnull()``.
    """
    null_filter = self.isnull()
    return ~null_filter
|
||||
|
||||
# We explicitly don't support classifier to classifier comparisons, since
|
||||
# the numbers likely don't mean the same thing. This may be relaxed in the
|
||||
# future, but for now we're starting conservatively.
|
||||
@expect_types(other=Number)
def eq(self, other):
    """
    Construct a Filter returning True for asset/date pairs where the output
    of ``self`` matches ``other``.

    Parameters
    ----------
    other : Number
        Label to compare against.

    Raises
    ------
    ValueError
        If ``other`` equals ``self.missing_value``.
    """
    # Missing values have NaN semantics, so comparing against the missing
    # value would yield an all-False filter.  That is almost certainly a
    # mistake, so it is reported as an error instead.
    compares_to_missing = (other == self.missing_value)
    if compares_to_missing:
        message = (
            "Comparison against self.missing_value ({value}) in"
            " {typename}.eq().\n"
            "Missing values have NaN semantics, so the "
            "requested comparison would always produce False.\n"
            "Use the isnull() method to check for missing values."
        ).format(
            value=other,
            typename=(type(self).__name__),
        )
        raise ValueError(message)

    # int() drops any fractional part of ``other`` before it is embedded
    # in the expression string.
    expression = "x_0 == {other}".format(other=int(other))
    return NumExprFilter.create(expression, binds=(self,))
|
||||
|
||||
@expect_types(other=Number)
def __ne__(self, other):
    """
    Construct a Filter returning True for asset/date pairs where the output
    of ``self`` is neither ``other`` nor ``self.missing_value``.

    Parameters
    ----------
    other : Number
        Label to compare against.  It is converted with ``int()`` before
        being embedded in the expression.

    Returns
    -------
    filter
        The result of ``NumExprFilter.create`` for the expression
        ``(x_0 != other) & (x_0 != missing_value)`` bound to ``self``.
    """
    # Locations holding the missing value are excluded explicitly, so they
    # produce False here just as they do for ``eq``.
    return NumExprFilter.create(
        "((x_0 != {other}) & (x_0 != {missing}))".format(
            other=int(other),
            missing=self.missing_value,
        ),
        binds=(self,),
    )
|
||||
|
||||
|
||||
class Everything(Classifier):
|
||||
"""
|
||||
@@ -44,6 +104,29 @@ class Everything(Classifier):
|
||||
)
|
||||
|
||||
|
||||
class Quantiles(SingleInputMixin, Classifier):
    """
    Classifier that labels the values of its single input with rowwise
    quantile buckets.

    The number of buckets comes from the ``bins`` parameter.  Locations
    excluded by the mask, and locations whose label comes back as nan, are
    labelled with ``missing_value`` (-1).
    """
    params = ('bins',)
    dtype = int64_dtype
    window_length = 0
    missing_value = -1

    def _compute(self, arrays, dates, assets, mask):
        nbins = self.params['bins']
        # Replace masked-out locations with nan before bucketing.
        visible = where(mask, arrays[0], nan)
        labels = quantiles(visible, nbins)
        # Any nan label -- whether it came from our mask or from the
        # input itself -- becomes self.missing_value.
        labels[isnan(labels)] = self.missing_value
        return labels.astype(int64_dtype)

    def short_repr(self):
        return '%s(%d)' % (type(self).__name__, self.params['bins'])
|
||||
|
||||
|
||||
class CustomClassifier(PositiveWindowLengthMixin, CustomTermMixin, Classifier):
|
||||
"""
|
||||
Base class for user-defined Classifiers.
|
||||
|
||||
@@ -8,7 +8,7 @@ from numbers import Number
|
||||
import numexpr
|
||||
from numexpr.necompiler import getExprNames
|
||||
from numpy import (
|
||||
empty,
|
||||
full,
|
||||
inf,
|
||||
)
|
||||
|
||||
@@ -229,7 +229,7 @@ class NumericalExpression(ComputableTerm):
|
||||
"""
|
||||
Compute our stored expression string with numexpr.
|
||||
"""
|
||||
out = empty(mask.shape, dtype=self.dtype)
|
||||
out = full(mask.shape, self.missing_value, dtype=self.dtype)
|
||||
# This writes directly into our output buffer.
|
||||
numexpr.evaluate(
|
||||
self._expr,
|
||||
|
||||
@@ -5,13 +5,13 @@ from functools import wraps
|
||||
from operator import attrgetter
|
||||
from numbers import Number
|
||||
|
||||
from numpy import inf, where, nanstd
|
||||
from numpy import inf, where
|
||||
from toolz import curry
|
||||
|
||||
from zipline.errors import UnknownRankMethod
|
||||
from zipline.lib.normalize import naive_grouped_rowwise_apply
|
||||
from zipline.lib.rank import masked_rankdata_2d
|
||||
from zipline.pipeline.classifiers import Classifier, Everything
|
||||
from zipline.pipeline.classifiers import Classifier, Everything, Quantiles
|
||||
from zipline.pipeline.mixins import (
|
||||
CustomTermMixin,
|
||||
LatestMixin,
|
||||
@@ -43,7 +43,7 @@ from zipline.pipeline.filters import (
|
||||
NullFilter,
|
||||
)
|
||||
from zipline.utils.input_validation import expect_types
|
||||
from zipline.utils.math_utils import nanmean
|
||||
from zipline.utils.math_utils import nanmean, nanstd
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
coerce_to_dtype,
|
||||
@@ -576,8 +576,13 @@ class Factor(RestrictedDTypeMixin, ComputableTerm):
|
||||
--------
|
||||
:meth:`pandas.DataFrame.groupby`
|
||||
"""
|
||||
# This is a named function so that it has a __name__ for use in the
|
||||
# graph repr of GroupedRowTransform.
|
||||
def demean(row):
|
||||
return row - nanmean(row)
|
||||
|
||||
return GroupedRowTransform(
|
||||
transform=lambda row: row - nanmean(row),
|
||||
transform=demean,
|
||||
factor=self,
|
||||
mask=mask,
|
||||
groupby=groupby,
|
||||
@@ -637,8 +642,13 @@ class Factor(RestrictedDTypeMixin, ComputableTerm):
|
||||
--------
|
||||
:meth:`pandas.DataFrame.groupby`
|
||||
"""
|
||||
# This is a named function so that it has a __name__ for use in the
|
||||
# graph repr of GroupedRowTransform.
|
||||
def zscore(row):
|
||||
return (row - nanmean(row)) / nanstd(row)
|
||||
|
||||
return GroupedRowTransform(
|
||||
transform=lambda row: (row - nanmean(row)) / nanstd(row),
|
||||
transform=zscore,
|
||||
factor=self,
|
||||
mask=mask,
|
||||
groupby=groupby,
|
||||
@@ -685,6 +695,105 @@ class Factor(RestrictedDTypeMixin, ComputableTerm):
|
||||
"""
|
||||
return Rank(self, method=method, ascending=ascending, mask=mask)
|
||||
|
||||
@expect_types(bins=int, mask=(Filter, NotSpecifiedType))
def quantiles(self, bins, mask=NotSpecified):
    """
    Construct a Classifier computing quantiles of the output of ``self``.

    Every non-NaN data point in the output is labelled with an integer
    value from 0 to (bins - 1). NaNs are labelled with -1.

    If ``mask`` is supplied, data points in locations for which ``mask``
    produces False are ignored and labelled with -1.

    Parameters
    ----------
    bins : int
        Number of bin labels to compute.
    mask : zipline.pipeline.Filter, optional
        Mask of values to ignore when computing quantiles. Defaults to
        ``self.mask`` when not specified.

    Returns
    -------
    quantiles : zipline.pipeline.classifiers.Quantiles
        A Classifier producing integer labels ranging from 0 to (bins - 1).
    """
    effective_mask = self.mask if mask is NotSpecified else mask
    return Quantiles(inputs=(self,), bins=bins, mask=effective_mask)
|
||||
|
||||
@expect_types(mask=(Filter, NotSpecifiedType))
def quartiles(self, mask=NotSpecified):
    """
    Construct a Classifier computing quartiles over the output of ``self``.

    Shorthand for ``self.quantiles(bins=4, mask=mask)``: every non-NaN
    data point in the output is labelled 0, 1, 2, or 3 for the first
    through fourth quartile of its row. NaN data points are labelled -1.

    If ``mask`` is supplied, data points in locations for which ``mask``
    produces False are ignored and labelled with -1.

    Parameters
    ----------
    mask : zipline.pipeline.Filter, optional
        Mask of values to ignore when computing quartiles.

    Returns
    -------
    quartiles : zipline.pipeline.classifiers.Quantiles
        A Classifier producing integer labels ranging from 0 to 3.
    """
    quartile_bins = 4
    return self.quantiles(bins=quartile_bins, mask=mask)
|
||||
|
||||
@expect_types(mask=(Filter, NotSpecifiedType))
|
||||
def quintiles(self, mask=NotSpecified):
|
||||
"""
|
||||
Construct a Classifier computing quintile labels on ``self``.
|
||||
|
||||
Every non-NaN data point in the output is labelled with a value of either
|
||||
0, 1, 2, 3, or 4, corresponding to quintiles over each row. NaN data
|
||||
points are labelled with -1.
|
||||
|
||||
If ``mask`` is supplied, ignore data points in locations for which
|
||||
``mask`` produces False, and emit a label of -1 at those locations.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mask : zipline.pipeline.Filter, optional
|
||||
Mask of values to ignore when computing quintiles.
|
||||
|
||||
Returns
|
||||
-------
|
||||
quintiles : zipline.pipeline.classifiers.Quantiles
|
||||
A Classifier producing integer labels ranging from 0 to 4.
|
||||
"""
|
||||
return self.quantiles(bins=5, mask=mask)
|
||||
|
||||
@expect_types(mask=(Filter, NotSpecifiedType))
|
||||
def deciles(self, mask=NotSpecified):
|
||||
"""
|
||||
Construct a Classifier computing decile labels on ``self``.
|
||||
|
||||
Every non-NaN data point in the output is labelled with a value from 0 to
|
||||
9, corresponding to deciles over each row. NaN data points are labelled
|
||||
with -1.
|
||||
|
||||
If ``mask`` is supplied, ignore data points in locations for which
|
||||
``mask`` produces False, and emit a label of -1 at those locations.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mask : zipline.pipeline.Filter, optional
|
||||
Mask of values to ignore when computing deciles.
|
||||
|
||||
Returns
|
||||
-------
|
||||
deciles : zipline.pipeline.classifiers.Quantiles
|
||||
A Classifier producing integer labels ranging from 0 to 9.
|
||||
"""
|
||||
return self.quantiles(bins=10, mask=mask)
|
||||
|
||||
def top(self, N, mask=NotSpecified):
|
||||
"""
|
||||
Construct a Filter matching the top N asset values of self each day.
|
||||
@@ -923,6 +1032,13 @@ class GroupedRowTransform(Factor):
|
||||
self.missing_value,
|
||||
)
|
||||
|
||||
@property
|
||||
def transform_name(self):
|
||||
return self._transform.__name__
|
||||
|
||||
def short_repr(self):
|
||||
return type(self).__name__ + '(%r)' % self.transform_name
|
||||
|
||||
|
||||
class Rank(SingleInputMixin, Factor):
|
||||
"""
|
||||
|
||||
@@ -23,7 +23,7 @@ from numexpr import evaluate
|
||||
|
||||
from zipline.pipeline.data import USEquityPricing
|
||||
from zipline.pipeline.mixins import SingleInputMixin
|
||||
from zipline.utils.control_flow import ignore_nanwarnings
|
||||
from zipline.utils.numpy_utils import ignore_nanwarnings
|
||||
from zipline.utils.input_validation import expect_types
|
||||
from zipline.utils.math_utils import (
|
||||
nanargmax,
|
||||
|
||||
@@ -82,7 +82,7 @@ def binary_operator(op):
|
||||
)
|
||||
elif isinstance(other, int): # Note that this is true for bool as well
|
||||
return NumExprFilter.create(
|
||||
"x_0 {op} ({constant})".format(op=op, constant=int(other)),
|
||||
"x_0 {op} {constant}".format(op=op, constant=int(other)),
|
||||
binds=(self,),
|
||||
)
|
||||
raise BadBinaryOperator(op, self, other)
|
||||
|
||||
@@ -21,6 +21,7 @@ from .core import ( # noqa
|
||||
make_trade_panel_for_asset_info,
|
||||
num_days_in_range,
|
||||
parameter_space,
|
||||
permute_rows,
|
||||
powerset,
|
||||
product_upper_triangle,
|
||||
seconds_to_timestamp,
|
||||
|
||||
+30
-6
@@ -31,6 +31,7 @@ from zipline.finance.order import ORDER_STATUS
|
||||
from zipline.pipeline.engine import SimplePipelineEngine
|
||||
from zipline.pipeline.loaders.testing import make_seeded_random_loader
|
||||
from zipline.utils import security_list
|
||||
from zipline.utils.input_validation import expect_dimensions
|
||||
from zipline.utils.tradingcalendar import trading_days
|
||||
|
||||
|
||||
@@ -408,7 +409,7 @@ def make_trade_panel_for_asset_info(dates,
|
||||
volume_step_by_date,
|
||||
volume_step_by_sid):
|
||||
"""
|
||||
Convert an asset info frame into a panel of trades, writing NaNs for
|
||||
|
||||
locations where assets did not exist.
|
||||
"""
|
||||
sids = list(asset_info.index)
|
||||
@@ -568,11 +569,17 @@ def check_allclose(actual,
|
||||
"""
|
||||
if type(actual) != type(desired):
|
||||
raise AssertionError("%s != %s" % (type(actual), type(desired)))
|
||||
return assert_allclose(actual, desired, rtol=rtol, atol=atol,
|
||||
err_msg=err_msg, verbose=verbose)
|
||||
return assert_allclose(
|
||||
actual,
|
||||
desired,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
err_msg=err_msg,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
|
||||
def check_arrays(x, y, err_msg='', verbose=True):
|
||||
def check_arrays(x, y, err_msg='', verbose=True, check_dtypes=True):
|
||||
"""
|
||||
Wrapper around np.testing.assert_array_equal that also verifies that inputs
|
||||
are ndarrays.
|
||||
@@ -581,8 +588,9 @@ def check_arrays(x, y, err_msg='', verbose=True):
|
||||
--------
|
||||
np.assert_array_equal
|
||||
"""
|
||||
if type(x) != type(y):
|
||||
raise AssertionError("%s != %s" % (type(x), type(y)))
|
||||
assert type(x) == type(y), "{x} != {y}".format(x=type(x), y=type(y))
|
||||
assert x.dtype == y.dtype, "{x.dtype} != {y.dtype}".format(x=x, y=y)
|
||||
|
||||
return assert_array_equal(x, y, err_msg=err_msg, verbose=True)
|
||||
|
||||
|
||||
@@ -885,6 +893,22 @@ def parameter_space(**params):
|
||||
return decorator
|
||||
|
||||
|
||||
@expect_dimensions(array=2)
|
||||
def permute_rows(seed, array):
|
||||
"""
|
||||
Shuffle each row in ``array`` based on permutations generated by ``seed``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
seed : int
|
||||
Seed for numpy.random.RandomState.
|
||||
array : np.ndarray[ndim=2]
|
||||
Array over which to apply permutations.
|
||||
"""
|
||||
rand = np.random.RandomState(seed)
|
||||
return np.apply_along_axis(rand.permutation, 1, array)
|
||||
|
||||
|
||||
@nottest
|
||||
def make_test_handler(testcase, *args, **kwargs):
|
||||
"""
|
||||
|
||||
@@ -2,10 +2,6 @@
|
||||
Control flow utilities.
|
||||
"""
|
||||
from six import iteritems
|
||||
from warnings import (
|
||||
catch_warnings,
|
||||
filterwarnings,
|
||||
)
|
||||
|
||||
|
||||
class nullctx(object):
|
||||
@@ -23,40 +19,6 @@ class nullctx(object):
|
||||
return False
|
||||
|
||||
|
||||
class WarningContext(object):
|
||||
"""
|
||||
Re-entrant contextmanager for contextually managing warnings.
|
||||
"""
|
||||
def __init__(self, *warning_specs):
|
||||
self._warning_specs = warning_specs
|
||||
self._catchers = []
|
||||
|
||||
def __enter__(self):
|
||||
catcher = catch_warnings()
|
||||
catcher.__enter__()
|
||||
self._catchers.append(catcher)
|
||||
for args, kwargs in self._warning_specs:
|
||||
filterwarnings(*args, **kwargs)
|
||||
return catcher
|
||||
|
||||
def __exit__(self, *exc_info):
|
||||
catcher = self._catchers.pop()
|
||||
return catcher.__exit__(*exc_info)
|
||||
|
||||
|
||||
def ignore_nanwarnings():
|
||||
"""
|
||||
Helper for building a WarningContext that ignores warnings from numpy's
|
||||
nanfunctions.
|
||||
"""
|
||||
return WarningContext(
|
||||
(
|
||||
('ignore',),
|
||||
{'category': RuntimeWarning, 'module': 'numpy.lib.nanfunctions'},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def invert(d):
|
||||
"""
|
||||
Invert a dictionary into a dictionary of sets.
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
from pprint import pformat
|
||||
|
||||
from six import viewkeys
|
||||
from six.moves import map
|
||||
|
||||
|
||||
def mapall(funcs, seq):
|
||||
"""
|
||||
Parameters
|
||||
@@ -20,3 +26,60 @@ def mapall(funcs, seq):
|
||||
for func in funcs:
|
||||
for elem in seq:
|
||||
yield func(elem)
|
||||
|
||||
|
||||
def same(*values):
|
||||
"""
|
||||
Check if all values in a sequence are equal.
|
||||
|
||||
Returns True on empty sequences.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> same(1, 1, 1, 1)
|
||||
True
|
||||
>>> same(1, 2, 1)
|
||||
False
|
||||
>>> same()
|
||||
True
|
||||
"""
|
||||
if not values:
|
||||
return True
|
||||
first, rest = values[0], values[1:]
|
||||
return all(value == first for value in rest)
|
||||
|
||||
|
||||
def _format_unequal_keys(dicts):
|
||||
return pformat([sorted(d.keys()) for d in dicts])
|
||||
|
||||
|
||||
def dzip_exact(*dicts):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
*dicts : iterable[dict]
|
||||
A sequence of dicts all sharing the same keys.
|
||||
|
||||
Returns
|
||||
-------
|
||||
zipped : dict
|
||||
A dict whose keys are the union of all keys in *dicts, and whose values
|
||||
are tuples of length len(dicts) containing the result of looking up
|
||||
each key in each dict.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If dicts don't all have the same keys.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> result = dzip_exact({'a': 1, 'b': 2}, {'a': 3, 'b': 4})
|
||||
>>> result == {'a': (1, 3), 'b': (2, 4)}
|
||||
True
|
||||
"""
|
||||
if not same(*map(viewkeys, dicts)):
|
||||
raise ValueError(
|
||||
"dict keys not all equal:\n\n%s" % _format_unequal_keys(dicts)
|
||||
)
|
||||
return {k: tuple(d[k] for d in dicts) for k in dicts[0]}
|
||||
|
||||
@@ -159,41 +159,43 @@ def expect_dtypes(*_pos, **named):
|
||||
name=name, dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
def _expect_dtype(_dtype_or_dtype_tuple):
|
||||
"""
|
||||
Factory for dtype-checking functions that work with the @preprocess
|
||||
decorator.
|
||||
"""
|
||||
# Slightly different messages for dtype and tuple of dtypes.
|
||||
if isinstance(_dtype_or_dtype_tuple, tuple):
|
||||
allowed_dtypes = _dtype_or_dtype_tuple
|
||||
else:
|
||||
allowed_dtypes = (_dtype_or_dtype_tuple,)
|
||||
template = (
|
||||
"%(funcname)s() expected a value with dtype {dtype_str} "
|
||||
"for argument '%(argname)s', but got %(actual)r instead."
|
||||
).format(dtype_str=' or '.join(repr(d.name) for d in allowed_dtypes))
|
||||
|
||||
def check_dtype(value):
|
||||
return getattr(value, 'dtype', None) not in allowed_dtypes
|
||||
|
||||
def display_bad_value(value):
|
||||
# If the bad value has a dtype, but it's wrong, show the dtype
|
||||
# name.
|
||||
try:
|
||||
return value.dtype.name
|
||||
except AttributeError:
|
||||
return value
|
||||
|
||||
return make_check(
|
||||
exc_type=TypeError,
|
||||
template=template,
|
||||
pred=check_dtype,
|
||||
actual=display_bad_value,
|
||||
)
|
||||
|
||||
return preprocess(**valmap(_expect_dtype, named))
|
||||
|
||||
|
||||
def _expect_dtype(_dtype_or_dtype_tuple):
|
||||
"""
|
||||
Factory for dtype-checking functions that work with the @preprocess decorator.
|
||||
"""
|
||||
# Slightly different messages for dtype and tuple of dtypes.
|
||||
if isinstance(_dtype_or_dtype_tuple, tuple):
|
||||
allowed_dtypes = _dtype_or_dtype_tuple
|
||||
else:
|
||||
allowed_dtypes = (_dtype_or_dtype_tuple,)
|
||||
template = (
|
||||
"%(funcname)s() expected a value with dtype {dtype_str} "
|
||||
"for argument '%(argname)s', but got %(actual)r instead."
|
||||
).format(dtype_str=' or '.join(repr(d.name) for d in allowed_dtypes))
|
||||
|
||||
def check_dtype(value):
|
||||
return getattr(value, 'dtype', None) not in allowed_dtypes
|
||||
|
||||
def display_bad_value(value):
|
||||
# If the bad value has a dtype, but it's wrong, show the dtype name.
|
||||
try:
|
||||
return value.dtype.name
|
||||
except AttributeError:
|
||||
return value
|
||||
|
||||
return make_check(
|
||||
exc_type=TypeError,
|
||||
template=template,
|
||||
pred=check_dtype,
|
||||
actual=display_bad_value,
|
||||
)
|
||||
|
||||
|
||||
def expect_types(*_pos, **named):
|
||||
"""
|
||||
Preprocessing decorator that verifies inputs have expected types.
|
||||
@@ -223,6 +225,26 @@ def expect_types(*_pos, **named):
|
||||
)
|
||||
)
|
||||
|
||||
def _expect_type(type_):
|
||||
# Slightly different messages for type and tuple of types.
|
||||
_template = (
|
||||
"%(funcname)s() expected a value of type {type_or_types} "
|
||||
"for argument '%(argname)s', but got %(actual)s instead."
|
||||
)
|
||||
if isinstance(type_, tuple):
|
||||
template = _template.format(
|
||||
type_or_types=' or '.join(map(_qualified_name, type_))
|
||||
)
|
||||
else:
|
||||
template = _template.format(type_or_types=_qualified_name(type_))
|
||||
|
||||
return make_check(
|
||||
TypeError,
|
||||
template,
|
||||
lambda v: not isinstance(v, type_),
|
||||
compose(_qualified_name, type),
|
||||
)
|
||||
|
||||
return preprocess(**valmap(_expect_type, named))
|
||||
|
||||
|
||||
@@ -273,30 +295,6 @@ def make_check(exc_type, template, pred, actual):
|
||||
return _check
|
||||
|
||||
|
||||
def _expect_type(type_):
|
||||
"""
|
||||
Factory for type-checking functions that work with the @preprocess decorator.
|
||||
"""
|
||||
# Slightly different messages for type and tuple of types.
|
||||
_template = (
|
||||
"%(funcname)s() expected a value of type {type_or_types} "
|
||||
"for argument '%(argname)s', but got %(actual)s instead."
|
||||
)
|
||||
if isinstance(type_, tuple):
|
||||
template = _template.format(
|
||||
type_or_types=' or '.join(map(_qualified_name, type_))
|
||||
)
|
||||
else:
|
||||
template = _template.format(type_or_types=_qualified_name(type_))
|
||||
|
||||
return make_check(
|
||||
TypeError,
|
||||
template,
|
||||
lambda v: not isinstance(v, type_),
|
||||
compose(_qualified_name, type),
|
||||
)
|
||||
|
||||
|
||||
def optional(type_):
|
||||
"""
|
||||
Helper for use with `expect_types` when an input can be `type_` or `None`.
|
||||
@@ -350,9 +348,63 @@ def expect_element(*_pos, **named):
|
||||
if _pos:
|
||||
raise TypeError("expect_element() only takes keyword arguments.")
|
||||
|
||||
def _expect_element(collection):
|
||||
template = (
|
||||
"%(funcname)s() expected a value in {collection} "
|
||||
"for argument '%(argname)s', but got %(actual)s instead."
|
||||
).format(collection=collection)
|
||||
return make_check(
|
||||
ValueError,
|
||||
template,
|
||||
complement(op.contains(collection)),
|
||||
repr,
|
||||
)
|
||||
return preprocess(**valmap(_expect_element, named))
|
||||
|
||||
|
||||
def expect_dimensions(**dimensions):
|
||||
"""
|
||||
Preprocessing decorator that verifies inputs are numpy arrays with a
|
||||
specific dimensionality.
|
||||
|
||||
Usage
|
||||
-----
|
||||
>>> from numpy import array
|
||||
>>> @expect_dimensions(x=1, y=2)
|
||||
... def foo(x, y):
|
||||
... return x[0] + y[0, 0]
|
||||
...
|
||||
>>> foo(array([1, 1]), array([[1, 1], [2, 2]]))
|
||||
2
|
||||
>>> foo(array([1, 1]), array([1, 1]))
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: foo() expected a 2-D array for argument 'y', but got a 1-D array instead. # noqa
|
||||
"""
|
||||
def _expect_dimension(expected_ndim):
|
||||
def _check(func, argname, argvalue):
|
||||
funcname = _qualified_name(func)
|
||||
actual_ndim = argvalue.ndim
|
||||
if actual_ndim != expected_ndim:
|
||||
if actual_ndim == 0:
|
||||
actual_repr = 'scalar'
|
||||
else:
|
||||
actual_repr = "%d-D array" % actual_ndim
|
||||
raise ValueError(
|
||||
"{func}() expected a {expected:d}-D array"
|
||||
" for argument {argname!r}, but got a {actual}"
|
||||
" instead.".format(
|
||||
func=funcname,
|
||||
expected=expected_ndim,
|
||||
argname=argname,
|
||||
actual=actual_repr,
|
||||
)
|
||||
)
|
||||
return argvalue
|
||||
return _check
|
||||
return preprocess(**valmap(_expect_dimension, dimensions))
|
||||
|
||||
|
||||
def coerce(from_, to, **to_kwargs):
|
||||
"""
|
||||
A preprocessing decorator that coerces inputs of a given type by passing
|
||||
@@ -391,16 +443,3 @@ def coerce(from_, to, **to_kwargs):
|
||||
|
||||
|
||||
coerce_string = partial(coerce, string_types)
|
||||
|
||||
|
||||
def _expect_element(collection):
|
||||
template = (
|
||||
"%(funcname)s() expected a value in {collection} "
|
||||
"for argument '%(argname)s', but got %(actual)s instead."
|
||||
).format(collection=collection)
|
||||
return make_check(
|
||||
ValueError,
|
||||
template,
|
||||
complement(op.contains(collection)),
|
||||
repr,
|
||||
)
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
Utilities for working with numpy arrays.
|
||||
"""
|
||||
from datetime import datetime
|
||||
from warnings import (
|
||||
catch_warnings,
|
||||
filterwarnings,
|
||||
)
|
||||
|
||||
from numpy import (
|
||||
broadcast,
|
||||
busday_count,
|
||||
@@ -219,3 +224,37 @@ def busday_count_mask_NaT(begindates,
|
||||
# Fill in entries where either comparison was NaT with nan in the output.
|
||||
out[beginmask | endmask] = nan
|
||||
return out
|
||||
|
||||
|
||||
class WarningContext(object):
|
||||
"""
|
||||
Re-usable contextmanager for contextually managing warnings.
|
||||
"""
|
||||
def __init__(self, *warning_specs):
|
||||
self._warning_specs = warning_specs
|
||||
self._catchers = []
|
||||
|
||||
def __enter__(self):
|
||||
catcher = catch_warnings()
|
||||
catcher.__enter__()
|
||||
self._catchers.append(catcher)
|
||||
for args, kwargs in self._warning_specs:
|
||||
filterwarnings(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_info):
|
||||
catcher = self._catchers.pop()
|
||||
return catcher.__exit__(*exc_info)
|
||||
|
||||
|
||||
def ignore_nanwarnings():
|
||||
"""
|
||||
Helper for building a WarningContext that ignores warnings from numpy's
|
||||
nanfunctions.
|
||||
"""
|
||||
return WarningContext(
|
||||
(
|
||||
('ignore',),
|
||||
{'category': RuntimeWarning, 'module': 'numpy.lib.nanfunctions'},
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user