BUG: NumericalExpressions fail to merge with too many inputs

This commit is contained in:
dmichalowicz
2016-03-21 15:54:52 -04:00
parent 4164ffdcb0
commit 7e83a8df5f
2 changed files with 49 additions and 1 deletions
@@ -8,7 +8,9 @@ from operator import (
methodcaller,
mul,
ne,
sub,
)
from string import ascii_uppercase
from unittest import TestCase
import numpy
@@ -147,6 +149,41 @@ class NumericalExpressionTestCase(TestCase):
with self.assertRaises(TypeError):
(f > f) > f
def test_many_inputs(self):
"""
Test adding NumericalExpressions with >10 inputs.
"""
# Create an initial NumericalExpression by adding two factors together.
f = self.f
expr = f + f
self.fake_raw_data = {f: full((5, 5), 0, float)}
expected = 0
# Alternate between adding and subtracting factors. Because subtraction
# is not commutative, this ensures that we are combining factors in the
# correct order.
ops = (add, sub)
for i, name in enumerate(ascii_uppercase):
op = ops[i % 2]
NewFactor = type(
name,
(Factor,),
dict(dtype=float64_dtype, inputs=(), window_length=0),
)
new_factor = NewFactor()
# Again we need a NumericalExpression, so add two factors together.
new_expr = new_factor + new_factor
self.fake_raw_data[new_factor] = full((5, 5), i + 1, float)
expr = op(expr, new_expr)
# Double the expected output since each factor is counted twice.
expected = op(expected, (i + 1) * 2)
self.check_output(expr, full((5, 5), expected, float))
def test_combine_datetimes(self):
with self.assertRaises(TypeError) as e:
self.d + self.d
+12 -1
View File
@@ -248,7 +248,18 @@ class NumericalExpression(ComputableTerm):
new_inputs.
"""
expr = self._expr
for idx, input_ in enumerate(self.inputs):
# If we have 11+ variables, some of our variable names may be
# substrings of other variable names. For example, we might have x_1,
# x_10, and x_100. By enumerating in reverse order, we ensure that
# every variable name which is a substring of another variable name is
# processed after the variable of which it is a substring. This
# guarantees that the substitution of any given variable index only
# ever affects exactly its own index. For example, if we have variables
# with indices going up to 100, we will process all of the x_1xx names
# before x_1x, which will be before x_1, so the substitution of x_1
# will not affect x_1x, which will not affect x_1xx.
for idx, input_ in reversed(list(enumerate(self.inputs))):
old_varname = "x_%d" % idx
# Temporarily rebind to x_temp_N so that we don't overwrite the
# same value multiple times.