diff --git a/tests/pipeline/test_numerical_expression.py b/tests/pipeline/test_numerical_expression.py index 102f29d7..1175b0d1 100644 --- a/tests/pipeline/test_numerical_expression.py +++ b/tests/pipeline/test_numerical_expression.py @@ -153,7 +153,7 @@ class NumericalExpressionTestCase(TestCase): message = e.exception.args[0] expected = ( "Don't know how to compute datetime64[ns] + datetime64[ns].\n" - "Arithmetic operators are only supported on Factors of dtype " + "Arithmetic operators are only supported between Factors of dtype " "'float64'." ) self.assertEqual(message, expected) @@ -164,7 +164,7 @@ class NumericalExpressionTestCase(TestCase): message = e.exception.args[0] expected = ( "Don't know how to compute datetime64[ns] * datetime64[ns].\n" - "Arithmetic operators are only supported on Factors of dtype " + "Arithmetic operators are only supported between Factors of dtype " "'float64'." ) self.assertEqual(message, expected) @@ -178,8 +178,8 @@ class NumericalExpressionTestCase(TestCase): message = e.exception.args[0] expected = ( "Don't know how to compute float64 {sym} datetime64[ns].\n" - "Arithmetic operators are only supported on Factors of " - "dtype 'float64'." + "Arithmetic operators are only supported between Factors" + " of dtype 'float64'." ).format(sym=sym) self.assertEqual(message, expected) @@ -188,8 +188,8 @@ class NumericalExpressionTestCase(TestCase): message = e.exception.args[0] expected = ( "Don't know how to compute datetime64[ns] {sym} float64.\n" - "Arithmetic operators are only supported on Factors of " - "dtype 'float64'." + "Arithmetic operators are only supported between Factors" + " of dtype 'float64'." ).format(sym=sym) self.assertEqual(message, expected) diff --git a/zipline/pipeline/factors/factor.py b/zipline/pipeline/factors/factor.py index e9df98e8..d438d7fe 100644 --- a/zipline/pipeline/factors/factor.py +++ b/zipline/pipeline/factors/factor.py @@ -18,7 +18,7 @@ from zipline.pipeline.mixins import ( PositiveWindowLengthMixin, SingleInputMixin, ) -from zipline.pipeline.term import ComputableTerm, NotSpecified +from zipline.pipeline.term import ComputableTerm, NotSpecified, Term from zipline.pipeline.expression import ( BadBinaryOperator, COMPARISONS, @@ -140,7 +140,7 @@ def binop_return_dtype(op, left, right): elif left != float64_dtype or right != float64_dtype: raise TypeError( "Don't know how to compute {left} {op} {right}.\n" - "Arithmetic operators are only supported on Factors of " + "Arithmetic operators are only supported between Factors of " "dtype 'float64'.".format( left=left.name, op=op, @@ -188,7 +188,7 @@ def binary_operator(op): # inputs. Look up and call the appropriate reflected operator with # ourself as the input. return commuted_method_getter(other)(self) - elif isinstance(other, Factor): + elif isinstance(other, Term): if self is other: return return_type( "x_0 {op} x_0".format(op=op), @@ -204,7 +204,8 @@ def binary_operator(op): return return_type( "x_0 {op} ({constant})".format(op=op, constant=other), binds=(self,), - # Interpret numeric literals as floats. + # .dtype access is safe here because coerce_numbers_to_my_dtype + # will convert any input numbers to numpy equivalents. dtype=binop_return_dtype(op, self.dtype, other.dtype) ) raise BadBinaryOperator(op, self, other) diff --git a/zipline/pipeline/filters/filter.py b/zipline/pipeline/filters/filter.py index a56ddad2..197f60c8 100644 --- a/zipline/pipeline/filters/filter.py +++ b/zipline/pipeline/filters/filter.py @@ -19,7 +19,7 @@ from zipline.pipeline.mixins import ( PositiveWindowLengthMixin, SingleInputMixin, ) -from zipline.pipeline.term import ComputableTerm +from zipline.pipeline.term import ComputableTerm, Term from zipline.pipeline.expression import ( BadBinaryOperator, FILTER_BINOPS, @@ -67,7 +67,9 @@ def binary_operator(op): # merging of inputs. Look up and call the appropriate # right-binding operator with ourself as the input. return commuted_method_getter(other)(self) - elif isinstance(other, Filter): + elif isinstance(other, Term): + if other.dtype != bool_dtype: + raise BadBinaryOperator(op, self, other) if self is other: return NumExprFilter.create( "x_0 {op} x_0".format(op=op),