mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-05 00:49:38 +08:00
MAINT: Group by loader and extra_rows
so that the mask and dates are the same for all the columns the loader is loading at a time.
This commit is contained in:
+17
-20
@@ -236,17 +236,10 @@ class SimplePipelineEngine(object):
|
||||
"""
|
||||
Load mask and mask row labels for term.
|
||||
"""
|
||||
mask = term.mask
|
||||
mask = term.mask if not term.atomic else self._root_mask_term
|
||||
offset = graph.extra_rows[mask] - graph.extra_rows[term]
|
||||
return workspace[mask][offset:], dates[offset:]
|
||||
|
||||
def _mask_and_dates_for_atomic_terms(self, terms, workspace, graph, dates):
|
||||
max_extra_rows = max(graph.extra_rows[term] for term in terms)
|
||||
|
||||
mask = self._root_mask_term
|
||||
offset = graph.extra_rows[mask] - max_extra_rows
|
||||
return workspace[mask][offset:], dates[offset:]
|
||||
|
||||
@staticmethod
|
||||
def _inputs_for_term(term, workspace, graph):
|
||||
"""
|
||||
@@ -281,10 +274,14 @@ class SimplePipelineEngine(object):
|
||||
out.append(input_data)
|
||||
return out
|
||||
|
||||
def _atomic_terms_for_loader(self, graph, loader):
|
||||
def _similar_atomic_terms(self, graph, atomic_term):
|
||||
loader_dispatch = self.loader_dispatch
|
||||
loader = loader_dispatch(atomic_term)
|
||||
extra_rows = graph.extra_rows[atomic_term]
|
||||
|
||||
for term in graph.atomic_terms:
|
||||
if loader_dispatch(term) == loader:
|
||||
if (loader_dispatch(term) == loader
|
||||
and graph.extra_rows[term] == extra_rows):
|
||||
yield term
|
||||
|
||||
def loader_dispatch(self, term):
|
||||
@@ -333,13 +330,18 @@ class SimplePipelineEngine(object):
|
||||
if term in workspace:
|
||||
continue
|
||||
|
||||
# Asset labels are always the same, but date labels vary by how
|
||||
# many extra rows are needed.
|
||||
mask, mask_dates = self._mask_and_dates_for_term(
|
||||
term, workspace, graph, dates
|
||||
)
|
||||
|
||||
if term.atomic:
|
||||
loader = loader_dispatch(term)
|
||||
to_load = sorted(self._atomic_terms_for_loader(graph, loader),
|
||||
key=lambda t: t.dataset)
|
||||
mask, mask_dates = self._mask_and_dates_for_atomic_terms(
|
||||
to_load, workspace, graph, dates,
|
||||
to_load = sorted(
|
||||
self._similar_atomic_terms(graph, term),
|
||||
key=lambda t: t.dataset
|
||||
)
|
||||
loader = loader_dispatch(term)
|
||||
loaded = loader.load_adjusted_array(
|
||||
to_load, mask_dates, assets, mask,
|
||||
)
|
||||
@@ -347,11 +349,6 @@ class SimplePipelineEngine(object):
|
||||
for loaded_term, adj_array in zip_longest(to_load, loaded):
|
||||
workspace[loaded_term] = adj_array
|
||||
else:
|
||||
# Asset labels are always the same, but date labels vary by how
|
||||
# many extra rows are needed.
|
||||
mask, mask_dates = self._mask_and_dates_for_term(
|
||||
term, workspace, graph, dates
|
||||
)
|
||||
workspace[term] = term._compute(
|
||||
self._inputs_for_term(term, workspace, graph),
|
||||
mask_dates,
|
||||
|
||||
Reference in New Issue
Block a user