mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-03 12:34:52 +08:00
BUG: NumericalExpressions fail to merge with too many inputs
This commit is contained in:
@@ -8,7 +8,9 @@ from operator import (
|
||||
methodcaller,
|
||||
mul,
|
||||
ne,
|
||||
sub,
|
||||
)
|
||||
from string import ascii_uppercase
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy
|
||||
@@ -147,6 +149,41 @@ class NumericalExpressionTestCase(TestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
(f > f) > f
|
||||
|
||||
def test_many_inputs(self):
    """
    Check that NumericalExpressions still combine correctly once the
    merged expression has more than 10 inputs.

    One fresh factor is created per uppercase ASCII letter and folded
    into a running expression, so the final expression references far
    more than 10 distinct inputs.
    """
    shape = (5, 5)

    # Adding a factor to itself yields the NumericalExpression that the
    # rest of the test keeps extending.
    base = self.f
    combined = base + base

    # The base factor is backed by zeros, so it contributes nothing to
    # the expected total.
    self.fake_raw_data = {base: full(shape, 0, float)}
    total = 0

    for value, name in enumerate(ascii_uppercase, 1):
        # Odd values are added and even values are subtracted. Since
        # subtraction is order-sensitive, a wrong merge order of the
        # inputs would produce a different result.
        combine = add if value % 2 else sub

        factor_cls = type(
            name,
            (Factor,),
            {"dtype": float64_dtype, "inputs": (), "window_length": 0},
        )
        term = factor_cls()

        # As above, double the factor to obtain a NumericalExpression.
        doubled = term + term
        self.fake_raw_data[term] = full(shape, value, float)
        combined = combine(combined, doubled)

        # Each factor appears twice in its expression, hence 2 * value.
        total = combine(total, 2 * value)

    self.check_output(combined, full(shape, total, float))
|
||||
|
||||
def test_combine_datetimes(self):
|
||||
with self.assertRaises(TypeError) as e:
|
||||
self.d + self.d
|
||||
|
||||
@@ -248,7 +248,18 @@ class NumericalExpression(ComputableTerm):
|
||||
new_inputs.
|
||||
"""
|
||||
expr = self._expr
|
||||
for idx, input_ in enumerate(self.inputs):
|
||||
|
||||
# If we have 11+ variables, some of our variable names may be
|
||||
# substrings of other variable names. For example, we might have x_1,
|
||||
# x_10, and x_100. By enumerating in reverse order, we ensure that
|
||||
# every variable name which is a substring of another variable name is
|
||||
# processed after the variable of which it is a substring. This
|
||||
# guarantees that the substitution of any given variable index only
|
||||
# ever affects exactly its own index. For example, if we have variables
|
||||
# with indices going up to 100, we will process all of the x_1xx names
|
||||
# before x_1x, which will be before x_1, so the substitution of x_1
|
||||
# will not affect x_1x, which will not affect x_1xx.
|
||||
for idx, input_ in reversed(list(enumerate(self.inputs))):
|
||||
old_varname = "x_%d" % idx
|
||||
# Temporarily rebind to x_temp_N so that we don't overwrite the
|
||||
# same value multiple times.
|
||||
|
||||
Reference in New Issue
Block a user