diff --git a/tests/pipeline/test_numerical_expression.py b/tests/pipeline/test_numerical_expression.py
index 1175b0d1..040ed124 100644
--- a/tests/pipeline/test_numerical_expression.py
+++ b/tests/pipeline/test_numerical_expression.py
@@ -8,7 +8,9 @@ from operator import (
     methodcaller,
     mul,
     ne,
+    sub,
 )
+from string import ascii_uppercase
 from unittest import TestCase
 
 import numpy
@@ -147,6 +149,41 @@ class NumericalExpressionTestCase(TestCase):
         with self.assertRaises(TypeError):
             (f > f) > f
 
+    def test_many_inputs(self):
+        """
+        Test adding NumericalExpressions with >10 inputs.
+        """
+        # Create an initial NumericalExpression by adding two factors together.
+        f = self.f
+        expr = f + f
+
+        self.fake_raw_data = {f: full((5, 5), 0, float)}
+        expected = 0
+
+        # Alternate between adding and subtracting factors. Because subtraction
+        # is not commutative, this ensures that we are combining factors in the
+        # correct order.
+        ops = (add, sub)
+
+        for i, name in enumerate(ascii_uppercase):
+            op = ops[i % 2]
+            NewFactor = type(
+                name,
+                (Factor,),
+                dict(dtype=float64_dtype, inputs=(), window_length=0),
+            )
+            new_factor = NewFactor()
+
+            # Again we need a NumericalExpression, so add two factors together.
+            new_expr = new_factor + new_factor
+            self.fake_raw_data[new_factor] = full((5, 5), i + 1, float)
+            expr = op(expr, new_expr)
+
+            # Double the expected output since each factor is counted twice.
+            expected = op(expected, (i + 1) * 2)
+
+        self.check_output(expr, full((5, 5), expected, float))
+
     def test_combine_datetimes(self):
         with self.assertRaises(TypeError) as e:
             self.d + self.d
diff --git a/zipline/pipeline/expression.py b/zipline/pipeline/expression.py
index 50d45a5e..1f7b976d 100644
--- a/zipline/pipeline/expression.py
+++ b/zipline/pipeline/expression.py
@@ -248,7 +248,18 @@ class NumericalExpression(ComputableTerm):
         new_inputs.
         """
         expr = self._expr
-        for idx, input_ in enumerate(self.inputs):
+
+        # If we have 11+ variables, some of our variable names may be
+        # substrings of other variable names. For example, we might have x_1,
+        # x_10, and x_100. By enumerating in reverse order, we ensure that
+        # every variable name which is a substring of another variable name is
+        # processed after the variable of which it is a substring. This
+        # guarantees that the substitution of any given variable index only
+        # ever affects exactly its own index. For example, if we have variables
+        # with indices going up to 100, we will process all of the x_1xx names
+        # before x_1x, which will be before x_1, so the substitution of x_1
+        # will not affect x_1x, which will not affect x_1xx.
+        for idx, input_ in reversed(list(enumerate(self.inputs))):
             old_varname = "x_%d" % idx
             # Temporarily rebind to x_temp_N so that we don't overwrite the
             # same value multiple times.