diff --git a/zipline/pipeline/engine.py b/zipline/pipeline/engine.py index ba5485bc..a8347776 100644 --- a/zipline/pipeline/engine.py +++ b/zipline/pipeline/engine.py @@ -236,17 +236,10 @@ class SimplePipelineEngine(object): """ Load mask and mask row labels for term. """ - mask = term.mask + mask = term.mask if not term.atomic else self._root_mask_term offset = graph.extra_rows[mask] - graph.extra_rows[term] return workspace[mask][offset:], dates[offset:] - def _mask_and_dates_for_atomic_terms(self, terms, workspace, graph, dates): - max_extra_rows = max(graph.extra_rows[term] for term in terms) - - mask = self._root_mask_term - offset = graph.extra_rows[mask] - max_extra_rows - return workspace[mask][offset:], dates[offset:] - @staticmethod def _inputs_for_term(term, workspace, graph): """ @@ -281,10 +274,14 @@ class SimplePipelineEngine(object): out.append(input_data) return out - def _atomic_terms_for_loader(self, graph, loader): + def _similar_atomic_terms(self, graph, atomic_term): loader_dispatch = self.loader_dispatch + loader = loader_dispatch(atomic_term) + extra_rows = graph.extra_rows[atomic_term] + for term in graph.atomic_terms: - if loader_dispatch(term) == loader: + if (loader_dispatch(term) == loader + and graph.extra_rows[term] == extra_rows): yield term def loader_dispatch(self, term): @@ -333,13 +330,18 @@ class SimplePipelineEngine(object): if term in workspace: continue + # Asset labels are always the same, but date labels vary by how + # many extra rows are needed. + mask, mask_dates = self._mask_and_dates_for_term( + term, workspace, graph, dates + ) + if term.atomic: - loader = loader_dispatch(term) - to_load = sorted(self._atomic_terms_for_loader(graph, loader), - key=lambda t: t.dataset) - mask, mask_dates = self._mask_and_dates_for_atomic_terms( - to_load, workspace, graph, dates, + to_load = sorted( + self._similar_atomic_terms(graph, term), + key=lambda t: t.dataset ) + loader = loader_dispatch(term) loaded = loader.load_adjusted_array( to_load, mask_dates, assets, mask, ) @@ -347,11 +349,6 @@ class SimplePipelineEngine(object): for loaded_term, adj_array in zip_longest(to_load, loaded): workspace[loaded_term] = adj_array else: - # Asset labels are always the same, but date labels vary by how - # many extra rows are needed. - mask, mask_dates = self._mask_and_dates_for_term( - term, workspace, graph, dates - ) workspace[term] = term._compute( self._inputs_for_term(term, workspace, graph), mask_dates,