mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-03 17:47:24 +08:00
BUG: Allow Filter comparisons with AssetExists.
Allow comparisons like SomeFilter() & AssetExists(). Previously such comparisons would fail because & and | on Filters explicitly checked that the other side of the operator was also a Filter. We now only enforce that the other side of the expression is a Term with a dtype of bool_.
This commit is contained in:
@@ -153,7 +153,7 @@ class NumericalExpressionTestCase(TestCase):
|
||||
message = e.exception.args[0]
|
||||
expected = (
|
||||
"Don't know how to compute datetime64[ns] + datetime64[ns].\n"
|
||||
"Arithmetic operators are only supported on Factors of dtype "
|
||||
"Arithmetic operators are only supported between Factors of dtype "
|
||||
"'float64'."
|
||||
)
|
||||
self.assertEqual(message, expected)
|
||||
@@ -164,7 +164,7 @@ class NumericalExpressionTestCase(TestCase):
|
||||
message = e.exception.args[0]
|
||||
expected = (
|
||||
"Don't know how to compute datetime64[ns] * datetime64[ns].\n"
|
||||
"Arithmetic operators are only supported on Factors of dtype "
|
||||
"Arithmetic operators are only supported between Factors of dtype "
|
||||
"'float64'."
|
||||
)
|
||||
self.assertEqual(message, expected)
|
||||
@@ -178,8 +178,8 @@ class NumericalExpressionTestCase(TestCase):
|
||||
message = e.exception.args[0]
|
||||
expected = (
|
||||
"Don't know how to compute float64 {sym} datetime64[ns].\n"
|
||||
"Arithmetic operators are only supported on Factors of "
|
||||
"dtype 'float64'."
|
||||
"Arithmetic operators are only supported between Factors"
|
||||
" of dtype 'float64'."
|
||||
).format(sym=sym)
|
||||
self.assertEqual(message, expected)
|
||||
|
||||
@@ -188,8 +188,8 @@ class NumericalExpressionTestCase(TestCase):
|
||||
message = e.exception.args[0]
|
||||
expected = (
|
||||
"Don't know how to compute datetime64[ns] {sym} float64.\n"
|
||||
"Arithmetic operators are only supported on Factors of "
|
||||
"dtype 'float64'."
|
||||
"Arithmetic operators are only supported between Factors"
|
||||
" of dtype 'float64'."
|
||||
).format(sym=sym)
|
||||
self.assertEqual(message, expected)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from zipline.pipeline.mixins import (
|
||||
PositiveWindowLengthMixin,
|
||||
SingleInputMixin,
|
||||
)
|
||||
from zipline.pipeline.term import ComputableTerm, NotSpecified
|
||||
from zipline.pipeline.term import ComputableTerm, NotSpecified, Term
|
||||
from zipline.pipeline.expression import (
|
||||
BadBinaryOperator,
|
||||
COMPARISONS,
|
||||
@@ -140,7 +140,7 @@ def binop_return_dtype(op, left, right):
|
||||
elif left != float64_dtype or right != float64_dtype:
|
||||
raise TypeError(
|
||||
"Don't know how to compute {left} {op} {right}.\n"
|
||||
"Arithmetic operators are only supported on Factors of "
|
||||
"Arithmetic operators are only supported between Factors of "
|
||||
"dtype 'float64'.".format(
|
||||
left=left.name,
|
||||
op=op,
|
||||
@@ -188,7 +188,7 @@ def binary_operator(op):
|
||||
# inputs. Look up and call the appropriate reflected operator with
|
||||
# ourself as the input.
|
||||
return commuted_method_getter(other)(self)
|
||||
elif isinstance(other, Factor):
|
||||
elif isinstance(other, Term):
|
||||
if self is other:
|
||||
return return_type(
|
||||
"x_0 {op} x_0".format(op=op),
|
||||
@@ -204,7 +204,8 @@ def binary_operator(op):
|
||||
return return_type(
|
||||
"x_0 {op} ({constant})".format(op=op, constant=other),
|
||||
binds=(self,),
|
||||
# Interpret numeric literals as floats.
|
||||
# .dtype access is safe here because coerce_numbers_to_my_dtype
|
||||
# will convert any input numbers to numpy equivalents.
|
||||
dtype=binop_return_dtype(op, self.dtype, other.dtype)
|
||||
)
|
||||
raise BadBinaryOperator(op, self, other)
|
||||
|
||||
@@ -19,7 +19,7 @@ from zipline.pipeline.mixins import (
|
||||
PositiveWindowLengthMixin,
|
||||
SingleInputMixin,
|
||||
)
|
||||
from zipline.pipeline.term import ComputableTerm
|
||||
from zipline.pipeline.term import ComputableTerm, Term
|
||||
from zipline.pipeline.expression import (
|
||||
BadBinaryOperator,
|
||||
FILTER_BINOPS,
|
||||
@@ -67,7 +67,9 @@ def binary_operator(op):
|
||||
# merging of inputs. Look up and call the appropriate
|
||||
# right-binding operator with ourself as the input.
|
||||
return commuted_method_getter(other)(self)
|
||||
elif isinstance(other, Filter):
|
||||
elif isinstance(other, Term):
|
||||
if other.dtype != bool_dtype:
|
||||
raise BadBinaryOperator(op, self, other)
|
||||
if self is other:
|
||||
return NumExprFilter.create(
|
||||
"x_0 {op} x_0".format(op=op),
|
||||
|
||||
Reference in New Issue
Block a user