From 5da03d2df5946199d14a5019816ebff9475958d1 Mon Sep 17 00:00:00 2001 From: Scott Sanderson Date: Mon, 3 Aug 2015 11:59:11 -0400 Subject: [PATCH] BUG: Make NumExprFilter return ndarray. - Previously it was returning a DataFrame because of how we applied an & with a DataFrame mask. The error was masked by the fact that `np.assert_array_equal` coerces inputs to arrays before comparing. - Added `zp.utils.test_utils.check_arrays`, which checks type equality before calling `np.assert_array_equal`. --- tests/modelling/test_numerical_expression.py | 4 ++-- zipline/modelling/filter.py | 5 ++--- zipline/utils/test_utils.py | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/tests/modelling/test_numerical_expression.py b/tests/modelling/test_numerical_expression.py index 1afd2d20..04d52159 100644 --- a/tests/modelling/test_numerical_expression.py +++ b/tests/modelling/test_numerical_expression.py @@ -18,7 +18,6 @@ from numpy import ( isnan, zeros, ) -from numpy.testing import assert_array_equal from pandas import ( DataFrame, date_range, @@ -30,6 +29,7 @@ from zipline.modelling.expression import ( NUMEXPR_MATH_FUNCS, ) from zipline.modelling.factor import TestingFactor +from zipline.utils.test_utils import check_arrays class F(TestingFactor): @@ -67,7 +67,7 @@ class NumericalExpressionTestCase(TestCase): [self.fake_raw_data[input_] for input_ in expr.inputs], self.mask, ) - assert_array_equal(result, full((5, 5), expected)) + check_arrays(result, expected) def check_constant_output(self, expr, expected): self.assertFalse(isnan(expected)) diff --git a/zipline/modelling/filter.py b/zipline/modelling/filter.py index d6ede154..53a7f43b 100644 --- a/zipline/modelling/filter.py +++ b/zipline/modelling/filter.py @@ -128,11 +128,10 @@ class NumExprFilter(NumericalExpression, Filter): """ Compute our result with numexpr, then apply `mask`. """ - numexpr_result = super(NumExprFilter, self).compute_from_arrays( + return super(NumExprFilter, self).compute_from_arrays( arrays, mask, - ) - return numexpr_result & mask + ) & mask.values class PercentileFilter(SingleInputMixin, Filter): diff --git a/zipline/utils/test_utils.py b/zipline/utils/test_utils.py index 9df753fc..79f0ffa5 100644 --- a/zipline/utils/test_utils.py +++ b/zipline/utils/test_utils.py @@ -4,6 +4,7 @@ from itertools import ( ) from logbook import FileHandler from mock import patch +from numpy.testing import assert_array_equal import operator from zipline.finance.blotter import ORDER_STATUS from zipline.utils import security_list @@ -311,3 +312,17 @@ def make_simple_asset_info(assets, start_date, end_date, symbols=None): 'exchange': 'TEST', } ) + + +def check_arrays(left, right, err_msg='', verbose=True): + """ + Wrapper around np.assert_array_equal that also verifies that inputs are + ndarrays. + + See Also + -------- + np.assert_array_equal + """ + if type(left) != type(right): + raise AssertionError("%s != %s" % (type(left), type(right))) + return assert_array_equal(left, right, err_msg=err_msg, verbose=True)