MAINT: Group by loader and extra_rows

so that the mask and dates are the same for all the columns
the loader is loading at a time.
This commit is contained in:
Richard Frank
2015-10-06 17:53:39 -04:00
parent ba0542a641
commit 7bd6b69a89
+17 -20
View File
@@ -236,17 +236,10 @@ class SimplePipelineEngine(object):
"""
Load mask and mask row labels for term.
"""
mask = term.mask
mask = term.mask if not term.atomic else self._root_mask_term
offset = graph.extra_rows[mask] - graph.extra_rows[term]
return workspace[mask][offset:], dates[offset:]
def _mask_and_dates_for_atomic_terms(self, terms, workspace, graph, dates):
max_extra_rows = max(graph.extra_rows[term] for term in terms)
mask = self._root_mask_term
offset = graph.extra_rows[mask] - max_extra_rows
return workspace[mask][offset:], dates[offset:]
@staticmethod
def _inputs_for_term(term, workspace, graph):
"""
@@ -281,10 +274,14 @@ class SimplePipelineEngine(object):
out.append(input_data)
return out
def _atomic_terms_for_loader(self, graph, loader):
def _similar_atomic_terms(self, graph, atomic_term):
loader_dispatch = self.loader_dispatch
loader = loader_dispatch(atomic_term)
extra_rows = graph.extra_rows[atomic_term]
for term in graph.atomic_terms:
if loader_dispatch(term) == loader:
if (loader_dispatch(term) == loader
and graph.extra_rows[term] == extra_rows):
yield term
def loader_dispatch(self, term):
@@ -333,13 +330,18 @@ class SimplePipelineEngine(object):
if term in workspace:
continue
# Asset labels are always the same, but date labels vary by how
# many extra rows are needed.
mask, mask_dates = self._mask_and_dates_for_term(
term, workspace, graph, dates
)
if term.atomic:
loader = loader_dispatch(term)
to_load = sorted(self._atomic_terms_for_loader(graph, loader),
key=lambda t: t.dataset)
mask, mask_dates = self._mask_and_dates_for_atomic_terms(
to_load, workspace, graph, dates,
to_load = sorted(
self._similar_atomic_terms(graph, term),
key=lambda t: t.dataset
)
loader = loader_dispatch(term)
loaded = loader.load_adjusted_array(
to_load, mask_dates, assets, mask,
)
@@ -347,11 +349,6 @@ class SimplePipelineEngine(object):
for loaded_term, adj_array in zip_longest(to_load, loaded):
workspace[loaded_term] = adj_array
else:
# Asset labels are always the same, but date labels vary by how
# many extra rows are needed.
mask, mask_dates = self._mask_and_dates_for_term(
term, workspace, graph, dates
)
workspace[term] = term._compute(
self._inputs_for_term(term, workspace, graph),
mask_dates,