mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 07:02:48 +08:00
PERF: Don't recalc similar atomic terms
This commit is contained in:
@@ -47,7 +47,7 @@ bcolz==0.10.0
|
||||
click==4.0.0
|
||||
|
||||
# FUNctional programming utilities
|
||||
toolz==0.7.2
|
||||
toolz==0.7.4
|
||||
|
||||
# Asset writer and finder
|
||||
sqlalchemy==1.0.8
|
||||
|
||||
+11
-12
@@ -13,12 +13,13 @@ from six import (
|
||||
)
|
||||
from six.moves import zip_longest
|
||||
from numpy import array
|
||||
|
||||
from pandas import (
|
||||
DataFrame,
|
||||
date_range,
|
||||
MultiIndex,
|
||||
)
|
||||
from toolz import groupby, juxt
|
||||
from toolz.curried.operator import getitem
|
||||
|
||||
from zipline.lib.adjusted_array import ensure_ndarray
|
||||
from zipline.errors import NoFurtherDataError
|
||||
@@ -274,17 +275,10 @@ class SimplePipelineEngine(object):
|
||||
out.append(input_data)
|
||||
return out
|
||||
|
||||
def _similar_atomic_terms(self, graph, atomic_term):
|
||||
loader_dispatch = self.loader_dispatch
|
||||
loader = loader_dispatch(atomic_term)
|
||||
extra_rows = graph.extra_rows[atomic_term]
|
||||
|
||||
for term in graph.atomic_terms:
|
||||
if (loader_dispatch(term) == loader
|
||||
and graph.extra_rows[term] == extra_rows):
|
||||
yield term
|
||||
|
||||
def loader_dispatch(self, term):
|
||||
# AssetExists is one of the atomic terms in the graph, so we look up
|
||||
# a loader here when grouping by loader, but since it's already in the
|
||||
# workspace, we don't actually use that group.
|
||||
if term is AssetExists():
|
||||
return None
|
||||
|
||||
@@ -322,6 +316,11 @@ class SimplePipelineEngine(object):
|
||||
# Copy the supplied initial workspace so we don't mutate it in place.
|
||||
workspace = initial_workspace.copy()
|
||||
|
||||
# If atomic terms share the same loader and extra_rows, load them all
|
||||
# together.
|
||||
atomic_group_key = juxt(loader_dispatch, getitem(graph.extra_rows))
|
||||
atomic_groups = groupby(atomic_group_key, graph.atomic_terms)
|
||||
|
||||
for term in graph.ordered():
|
||||
# `term` may have been supplied in `initial_workspace`, and in the
|
||||
# future we may pre-compute atomic terms coming from the same
|
||||
@@ -338,7 +337,7 @@ class SimplePipelineEngine(object):
|
||||
|
||||
if term.atomic:
|
||||
to_load = sorted(
|
||||
self._similar_atomic_terms(graph, term),
|
||||
atomic_groups[atomic_group_key(term)],
|
||||
key=lambda t: t.dataset
|
||||
)
|
||||
loader = loader_dispatch(term)
|
||||
|
||||
Reference in New Issue
Block a user