From 99de89c8172b9cc068588525e367a1ae3bec045d Mon Sep 17 00:00:00 2001 From: Richard Frank Date: Fri, 9 Oct 2015 18:17:11 -0400 Subject: [PATCH] PERF: Don't recalc similar atomic terms --- etc/requirements.txt | 2 +- zipline/pipeline/engine.py | 23 +++++++++++------------ 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/etc/requirements.txt b/etc/requirements.txt index 7125a63f..5c00af0a 100644 --- a/etc/requirements.txt +++ b/etc/requirements.txt @@ -47,7 +47,7 @@ bcolz==0.10.0 click==4.0.0 # FUNctional programming utilities -toolz==0.7.2 +toolz==0.7.4 # Asset writer and finder sqlalchemy==1.0.8 diff --git a/zipline/pipeline/engine.py b/zipline/pipeline/engine.py index 8abccb83..64cc1306 100644 --- a/zipline/pipeline/engine.py +++ b/zipline/pipeline/engine.py @@ -13,12 +13,13 @@ from six import ( ) from six.moves import zip_longest from numpy import array - from pandas import ( DataFrame, date_range, MultiIndex, ) +from toolz import groupby, juxt +from toolz.curried.operator import getitem from zipline.lib.adjusted_array import ensure_ndarray from zipline.errors import NoFurtherDataError @@ -274,17 +275,10 @@ class SimplePipelineEngine(object): out.append(input_data) return out - def _similar_atomic_terms(self, graph, atomic_term): - loader_dispatch = self.loader_dispatch - loader = loader_dispatch(atomic_term) - extra_rows = graph.extra_rows[atomic_term] - - for term in graph.atomic_terms: - if (loader_dispatch(term) == loader - and graph.extra_rows[term] == extra_rows): - yield term - def loader_dispatch(self, term): + # AssetExists is one of the atomic terms in the graph, so we look up + # a loader here when grouping by loader, but since it's already in the + # workspace, we don't actually use that group. if term is AssetExists(): return None @@ -322,6 +316,11 @@ class SimplePipelineEngine(object): # Copy the supplied initial workspace so we don't mutate it in place. workspace = initial_workspace.copy() + # If atomic terms share the same loader and extra_rows, load them all + # together. + atomic_group_key = juxt(loader_dispatch, getitem(graph.extra_rows)) + atomic_groups = groupby(atomic_group_key, graph.atomic_terms) + for term in graph.ordered(): # `term` may have been supplied in `initial_workspace`, and in the # future we may pre-compute atomic terms coming from the same @@ -338,7 +337,7 @@ class SimplePipelineEngine(object): if term.atomic: to_load = sorted( - self._similar_atomic_terms(graph, term), + atomic_groups[atomic_group_key(term)], key=lambda t: t.dataset ) loader = loader_dispatch(term)