From 0123bb8a97186a1cd95f772c36b23044fc4812df Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Wed, 5 Oct 2016 17:26:25 -0400 Subject: [PATCH] ENH: prune the graph based on the initial workspace --- tests/pipeline/test_engine.py | 42 +++++++++++++++----------- tests/pipeline/test_term.py | 10 ++++-- zipline/pipeline/engine.py | 2 +- zipline/pipeline/graph.py | 57 +++++++++++++++++++++++++---------- zipline/pipeline/visualize.py | 5 +-- 5 files changed, 78 insertions(+), 38 deletions(-) diff --git a/tests/pipeline/test_engine.py b/tests/pipeline/test_engine.py index 23c843b6..0fa140b4 100644 --- a/tests/pipeline/test_engine.py +++ b/tests/pipeline/test_engine.py @@ -66,6 +66,7 @@ from zipline.pipeline.term import InputDates from zipline.testing import ( AssetID, AssetIDPlusDay, + ExplodingObject, check_arrays, make_alternating_boolean_array, make_cascading_boolean_array, @@ -1320,18 +1321,11 @@ class StringColumnTestCase(WithSeededRandomPipelineEngine, class PopulateInitialWorkspaceTestCase(WithConstantInputs, ZiplineTestCase): - def make_engine(self, populate_initial_workspace): - return SimplePipelineEngine( - lambda column: self.loader, - self.dates, - self.asset_finder, - populate_initial_workspace=populate_initial_workspace, - ) - def test_populate_default_workspace(self): column = USEquityPricing.low base_term = column.latest - term = base_term + 1 + term = (base_term + 1).alias('term') + composed_term = term + 1 column_value = self.constants[column] precomputed_value = -column_value @@ -1343,25 +1337,39 @@ class PopulateInitialWorkspaceTestCase(WithConstantInputs, ZiplineTestCase): return assoc( initial_workspace, term, - full((len(dates), len(assets)), precomputed_value), + full( + (len(dates), len(assets)), + precomputed_value, + dtype=float64, + ), ) - # I resisted the urge to use ``make_engine`` as a decorator here - # because Scott would have yelled at me. 
- engine = self.make_engine(populate_initial_workspace) + def dispatcher(column): + if column is base_term: + # the base_term should never be loaded, its initial refcount + # should be zero + return ExplodingObject() + return self.loader + + engine = SimplePipelineEngine( + dispatcher, + self.dates, + self.asset_finder, + populate_initial_workspace=populate_initial_workspace, + ) results = engine.run_pipeline( Pipeline({ - 'term-in-initial-workspace': term, - 'term-not-in-initial-workspace': base_term, + 'term': term, + 'composed_term': composed_term, }), self.dates[0], self.dates[-1], ) self.assertTrue( - (results['term-in-initial-workspace'] == precomputed_value).all(), + (results['term'] == precomputed_value).all(), ) self.assertTrue( - (results['term-not-in-initial-workspace'] == column_value).all(), + (results['composed_term'] == (precomputed_value + 1)).all(), ) diff --git a/tests/pipeline/test_term.py b/tests/pipeline/test_term.py index c58dd4e0..39d507ca 100644 --- a/tests/pipeline/test_term.py +++ b/tests/pipeline/test_term.py @@ -195,8 +195,14 @@ class DependencyResolutionTestCase(WithTradingSessions, ZiplineTestCase): self.assertIn(SomeDataSet.bar, resolution_order) self.assertIn(SomeFactor(), resolution_order) - self.assertEqual(graph.node[SomeDataSet.foo]['extra_rows'], 4) - self.assertEqual(graph.node[SomeDataSet.bar]['extra_rows'], 4) + self.assertEqual( + graph.graph.node[SomeDataSet.foo]['extra_rows'], + 4, + ) + self.assertEqual( + graph.graph.node[SomeDataSet.bar]['extra_rows'], + 4, + ) for foobar in gen_equivalent_factors(): check_output(self.make_execution_plan(to_dict([foobar]))) diff --git a/zipline/pipeline/engine.py b/zipline/pipeline/engine.py index e2a098e0..6c44703e 100644 --- a/zipline/pipeline/engine.py +++ b/zipline/pipeline/engine.py @@ -362,7 +362,7 @@ class SimplePipelineEngine(object): refcounts = graph.initial_refcounts(workspace) - for term in graph.ordered(): + for term in graph.execution_order(refcounts): # `term` may have 
been supplied in `initial_workspace`, and in the # future we may pre-compute loadable terms coming from the same # dataset. In either case, we will already have an entry for this diff --git a/zipline/pipeline/graph.py b/zipline/pipeline/graph.py index cef8e7a0..8919fcb6 100644 --- a/zipline/pipeline/graph.py +++ b/zipline/pipeline/graph.py @@ -16,7 +16,7 @@ class CyclicDependency(Exception): pass -class TermGraph(DiGraph): +class TermGraph(object): """ An abstract representation of Pipeline Term dependencies. @@ -44,7 +44,7 @@ class TermGraph(DiGraph): ExecutionPlan """ def __init__(self, terms): - super(TermGraph, self).__init__() + self.graph = DiGraph() self._frozen = False parents = set() @@ -54,7 +54,6 @@ class TermGraph(DiGraph): assert not parents self._outputs = terms - self._ordered = topological_sort(self) # Mark that no more terms should be added to the graph. self._frozen = True @@ -79,11 +78,11 @@ class TermGraph(DiGraph): parents.add(term) - self.add_node(term) + self.graph.add_node(term) for dependency in term.dependencies: self._add_to_graph(dependency, parents) - self.add_edge(dependency, term) + self.graph.add_edge(dependency, term) parents.remove(term) @@ -94,15 +93,25 @@ class TermGraph(DiGraph): """ return self._outputs + def execution_order(self, refcounts): + """ + Return a topologically-sorted iterator over the terms in ``self`` which + need to be computed. + """ + return iter(topological_sort( + self.graph.subgraph( + {term for term, refcount in refcounts.items() if refcount > 0}, + ), + )) + def ordered(self): - """ - Return a topologically-sorted iterator over the terms in `self`. 
- """ - return iter(self._ordered) + return iter(topological_sort(self.graph)) @lazyval def loadable_terms(self): - return tuple(term for term in self if isinstance(term, LoadableTerm)) + return tuple( + term for term in self.graph if isinstance(term, LoadableTerm) + ) @lazyval def jpeg(self): @@ -132,15 +141,28 @@ class TermGraph(DiGraph): nodes get one extra reference to ensure that they're still in the graph at the end of execution. """ - refcounts = self.out_degree() + refcounts = self.graph.out_degree() for t in self.outputs.values(): refcounts[t] += 1 for t in initial_terms: - self.decref_dependencies(t, refcounts) + self._decref_recursive(t, refcounts, set()) return refcounts + def _decref_recursive(self, term, refcounts, garbage): + """ + Decrement terms recursively to build the initial workspace. + """ + # Edges are tuple of (from, to). + for parent, _ in self.graph.in_edges([term]): + refcounts[parent] -= 1 + # No one else depends on this term. Remove it from the + # workspace to conserve memory. + if refcounts[parent] == 0: + garbage.add(parent) + self._decref_recursive(parent, refcounts, garbage) + def decref_dependencies(self, term, refcounts): """ Decrement in-edges for ``term`` after computation. @@ -159,7 +181,7 @@ class TermGraph(DiGraph): """ garbage = set() # Edges are tuple of (from, to). - for parent, _ in self.in_edges([term]): + for parent, _ in self.graph.in_edges([term]): refcounts[parent] -= 1 # No one else depends on this term. Remove it from the # workspace to conserve memory. @@ -167,6 +189,9 @@ class TermGraph(DiGraph): garbage.add(parent) return garbage + def __iter__(self): + return iter(self.graph) + class ExecutionPlan(TermGraph): """ @@ -326,7 +351,7 @@ class ExecutionPlan(TermGraph): # How much bigger is the array for ``dep`` compared to ``term``? # How much of that difference did I ask for. 
(term, dep): (extra[dep] - extra[term]) - requested_extra_rows - for term in self + for term in self.graph for dep, requested_extra_rows in term.dependencies.items() } @@ -366,14 +391,14 @@ class ExecutionPlan(TermGraph): """ return { term: attrs['extra_rows'] - for term, attrs in iteritems(self.node) + for term, attrs in iteritems(self.graph.node) } def _ensure_extra_rows(self, term, N): """ Ensure that we're going to compute at least N extra rows of `term`. """ - attrs = self.node[term] + attrs = self.graph.node[term] attrs['extra_rows'] = max(N, attrs.get('extra_rows', 0)) def mask_and_dates_for_term(self, diff --git a/zipline/pipeline/visualize.py b/zipline/pipeline/visualize.py index f0356b43..fdc34b70 100644 --- a/zipline/pipeline/visualize.py +++ b/zipline/pipeline/visualize.py @@ -115,13 +115,14 @@ def _render(g, out, format_, include_asset_exists=False): add_term_node(f, term) # Write intermediate results. - for term in filter_nodes(include_asset_exists, topological_sort(g)): + for term in filter_nodes(include_asset_exists, + topological_sort(g.graph)): if term in in_nodes or term in out_nodes: continue add_term_node(f, term) # Write edges - for source, dest in g.edges(): + for source, dest in g.graph.edges(): if source is AssetExists() and not include_asset_exists: continue add_edge(f, id(source), id(dest))