diff --git a/tests/pipeline/test_filter.py b/tests/pipeline/test_filter.py index 461f5beb..a3e090f6 100644 --- a/tests/pipeline/test_filter.py +++ b/tests/pipeline/test_filter.py @@ -410,7 +410,7 @@ class FilterTestCase(BasePipelineTestCase): strictly_true_filter = StrictlyTrueFilter( inputs=(InputFilter(), ), - window_length=self.default_shape[0] + window_length=(self.default_shape[0]-1) ) results = self.run_graph( @@ -418,11 +418,18 @@ class FilterTestCase(BasePipelineTestCase): initial_workspace={InputFilter(): data} ) - expected_result = full(self.default_shape[1], False, dtype=bool) - expected_result[0] = True + expected_result_0 = full(self.default_shape[1], False, dtype=bool) + expected_result_0[0] = True check_arrays( - reshape(results['Filter'], expected_result.shape[0]), - expected_result, + reshape(results['Filter'][0], expected_result_0.shape[0]), + expected_result_0, + ) + + expected_result_1 = full(self.default_shape[1], False, dtype=bool) + expected_result_1[:2] = True + check_arrays( + reshape(results['Filter'][1], expected_result_1.shape[0]), + expected_result_1, ) @parameter_space(factor_len=[2, 3, 4])