BUG: Make NumExprFilter return ndarray.

- Previously it was returning a DataFrame because of how we applied an &
  with a DataFrame mask.  The error was masked by the fact that
  `np.assert_array_equal` coerces inputs to arrays before comparing.

- Added `zp.utils.test_utils.check_arrays`, which checks type equality
  before calling `np.assert_array_equal`.
This commit is contained in:
Scott Sanderson
2015-08-03 11:59:11 -04:00
parent 67c56f768b
commit 5da03d2df5
3 changed files with 19 additions and 5 deletions
+2 -2
View File
@@ -18,7 +18,6 @@ from numpy import (
isnan,
zeros,
)
from numpy.testing import assert_array_equal
from pandas import (
DataFrame,
date_range,
@@ -30,6 +29,7 @@ from zipline.modelling.expression import (
NUMEXPR_MATH_FUNCS,
)
from zipline.modelling.factor import TestingFactor
from zipline.utils.test_utils import check_arrays
class F(TestingFactor):
@@ -67,7 +67,7 @@ class NumericalExpressionTestCase(TestCase):
[self.fake_raw_data[input_] for input_ in expr.inputs],
self.mask,
)
assert_array_equal(result, full((5, 5), expected))
check_arrays(result, expected)
def check_constant_output(self, expr, expected):
self.assertFalse(isnan(expected))
+2 -3
View File
@@ -128,11 +128,10 @@ class NumExprFilter(NumericalExpression, Filter):
"""
Compute our result with numexpr, then apply `mask`.
"""
numexpr_result = super(NumExprFilter, self).compute_from_arrays(
return super(NumExprFilter, self).compute_from_arrays(
arrays,
mask,
)
return numexpr_result & mask
) & mask.values
class PercentileFilter(SingleInputMixin, Filter):
+15
View File
@@ -4,6 +4,7 @@ from itertools import (
)
from logbook import FileHandler
from mock import patch
from numpy.testing import assert_array_equal
import operator
from zipline.finance.blotter import ORDER_STATUS
from zipline.utils import security_list
@@ -311,3 +312,17 @@ def make_simple_asset_info(assets, start_date, end_date, symbols=None):
'exchange': 'TEST',
}
)
def check_arrays(left, right, err_msg='', verbose=True):
"""
Wrapper around np.assert_array_equal that also verifies that inputs are
ndarrays.
See Also
--------
np.assert_array_equal
"""
if type(left) != type(right):
raise AssertionError("%s != %s" % (type(left), type(right)))
return assert_array_equal(left, right, err_msg=err_msg, verbose=True)