mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 21:02:39 +08:00
BUG: Make NumExprFilter return ndarray.
- Previously it was returning a DataFrame because of how we applied an & with a DataFrame mask. The error was masked by the fact that `np.assert_array_equal` coerces inputs to arrays before comparing. - Added `zp.utils.test_utils.check_arrays`, which checks type equality before calling `np.assert_array_equal`.
This commit is contained in:
@@ -18,7 +18,6 @@ from numpy import (
|
||||
isnan,
|
||||
zeros,
|
||||
)
|
||||
from numpy.testing import assert_array_equal
|
||||
from pandas import (
|
||||
DataFrame,
|
||||
date_range,
|
||||
@@ -30,6 +29,7 @@ from zipline.modelling.expression import (
|
||||
NUMEXPR_MATH_FUNCS,
|
||||
)
|
||||
from zipline.modelling.factor import TestingFactor
|
||||
from zipline.utils.test_utils import check_arrays
|
||||
|
||||
|
||||
class F(TestingFactor):
|
||||
@@ -67,7 +67,7 @@ class NumericalExpressionTestCase(TestCase):
|
||||
[self.fake_raw_data[input_] for input_ in expr.inputs],
|
||||
self.mask,
|
||||
)
|
||||
assert_array_equal(result, full((5, 5), expected))
|
||||
check_arrays(result, expected)
|
||||
|
||||
def check_constant_output(self, expr, expected):
|
||||
self.assertFalse(isnan(expected))
|
||||
|
||||
@@ -128,11 +128,10 @@ class NumExprFilter(NumericalExpression, Filter):
|
||||
"""
|
||||
Compute our result with numexpr, then apply `mask`.
|
||||
"""
|
||||
numexpr_result = super(NumExprFilter, self).compute_from_arrays(
|
||||
return super(NumExprFilter, self).compute_from_arrays(
|
||||
arrays,
|
||||
mask,
|
||||
)
|
||||
return numexpr_result & mask
|
||||
) & mask.values
|
||||
|
||||
|
||||
class PercentileFilter(SingleInputMixin, Filter):
|
||||
|
||||
@@ -4,6 +4,7 @@ from itertools import (
|
||||
)
|
||||
from logbook import FileHandler
|
||||
from mock import patch
|
||||
from numpy.testing import assert_array_equal
|
||||
import operator
|
||||
from zipline.finance.blotter import ORDER_STATUS
|
||||
from zipline.utils import security_list
|
||||
@@ -311,3 +312,17 @@ def make_simple_asset_info(assets, start_date, end_date, symbols=None):
|
||||
'exchange': 'TEST',
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def check_arrays(left, right, err_msg='', verbose=True):
|
||||
"""
|
||||
Wrapper around np.assert_array_equal that also verifies that inputs are
|
||||
ndarrays.
|
||||
|
||||
See Also
|
||||
--------
|
||||
np.assert_array_equal
|
||||
"""
|
||||
if type(left) != type(right):
|
||||
raise AssertionError("%s != %s" % (type(left), type(right)))
|
||||
return assert_array_equal(left, right, err_msg=err_msg, verbose=True)
|
||||
|
||||
Reference in New Issue
Block a user