ENH: prune the graph based on the initial workspace

This commit is contained in:
Joe Jevnik
2016-10-05 17:26:25 -04:00
parent e3e4ad2735
commit 0123bb8a97
5 changed files with 78 additions and 38 deletions
+25 -17
View File
@@ -66,6 +66,7 @@ from zipline.pipeline.term import InputDates
from zipline.testing import (
AssetID,
AssetIDPlusDay,
ExplodingObject,
check_arrays,
make_alternating_boolean_array,
make_cascading_boolean_array,
@@ -1320,18 +1321,11 @@ class StringColumnTestCase(WithSeededRandomPipelineEngine,
class PopulateInitialWorkspaceTestCase(WithConstantInputs, ZiplineTestCase):
def make_engine(self, populate_initial_workspace):
return SimplePipelineEngine(
lambda column: self.loader,
self.dates,
self.asset_finder,
populate_initial_workspace=populate_initial_workspace,
)
def test_populate_default_workspace(self):
column = USEquityPricing.low
base_term = column.latest
term = base_term + 1
term = (base_term + 1).alias('term')
composed_term = term + 1
column_value = self.constants[column]
precomputed_value = -column_value
@@ -1343,25 +1337,39 @@ class PopulateInitialWorkspaceTestCase(WithConstantInputs, ZiplineTestCase):
return assoc(
initial_workspace,
term,
full((len(dates), len(assets)), precomputed_value),
full(
(len(dates), len(assets)),
precomputed_value,
dtype=float64,
),
)
# I resisted the urge to use ``make_engine`` as a decorator here
# because Scott would have yelled at me.
engine = self.make_engine(populate_initial_workspace)
def dispatcher(column, _pruned_input=column):
    """
    Loader dispatcher that blows up if a pruned term is ever dispatched.

    ``column`` is whatever the engine asks a loader for.

    ``_pruned_input`` is not meant to be passed by callers. Its default is
    evaluated in the enclosing scope when this function is defined, so it
    captures the raw column that ``base_term`` was built from. Without it
    the ``column`` parameter would shadow that outer name.
    """
    # ``term`` is precomputed in the initial workspace, so nothing
    # underneath it should ever be loaded: the initial refcount of
    # ``base_term`` and of the raw column it reads from should be zero.
    #
    # NOTE(review): ``base_term`` is ``column.latest``, not the column
    # itself, so checking only ``column is base_term`` would presumably
    # never match what the engine dispatches. The raw column is checked
    # as well so that a failure to prune is actually detected.
    if column is base_term or column is _pruned_input:
        return ExplodingObject()
    return self.loader
engine = SimplePipelineEngine(
dispatcher,
self.dates,
self.asset_finder,
populate_initial_workspace=populate_initial_workspace,
)
results = engine.run_pipeline(
Pipeline({
'term-in-initial-workspace': term,
'term-not-in-initial-workspace': base_term,
'term': term,
'composed_term': composed_term,
}),
self.dates[0],
self.dates[-1],
)
self.assertTrue(
(results['term-in-initial-workspace'] == precomputed_value).all(),
(results['term'] == precomputed_value).all(),
)
self.assertTrue(
(results['term-not-in-initial-workspace'] == column_value).all(),
(results['composed_term'] == (precomputed_value + 1)).all(),
)
+8 -2
View File
@@ -195,8 +195,14 @@ class DependencyResolutionTestCase(WithTradingSessions, ZiplineTestCase):
self.assertIn(SomeDataSet.bar, resolution_order)
self.assertIn(SomeFactor(), resolution_order)
self.assertEqual(graph.node[SomeDataSet.foo]['extra_rows'], 4)
self.assertEqual(graph.node[SomeDataSet.bar]['extra_rows'], 4)
self.assertEqual(
graph.graph.node[SomeDataSet.foo]['extra_rows'],
4,
)
self.assertEqual(
graph.graph.node[SomeDataSet.bar]['extra_rows'],
4,
)
for foobar in gen_equivalent_factors():
check_output(self.make_execution_plan(to_dict([foobar])))
+1 -1
View File
@@ -362,7 +362,7 @@ class SimplePipelineEngine(object):
refcounts = graph.initial_refcounts(workspace)
for term in graph.ordered():
for term in graph.execution_order(refcounts):
# `term` may have been supplied in `initial_workspace`, and in the
# future we may pre-compute loadable terms coming from the same
# dataset. In either case, we will already have an entry for this
+41 -16
View File
@@ -16,7 +16,7 @@ class CyclicDependency(Exception):
pass
class TermGraph(DiGraph):
class TermGraph(object):
"""
An abstract representation of Pipeline Term dependencies.
@@ -44,7 +44,7 @@ class TermGraph(DiGraph):
ExecutionPlan
"""
def __init__(self, terms):
super(TermGraph, self).__init__()
self.graph = DiGraph()
self._frozen = False
parents = set()
@@ -54,7 +54,6 @@ class TermGraph(DiGraph):
assert not parents
self._outputs = terms
self._ordered = topological_sort(self)
# Mark that no more terms should be added to the graph.
self._frozen = True
@@ -79,11 +78,11 @@ class TermGraph(DiGraph):
parents.add(term)
self.add_node(term)
self.graph.add_node(term)
for dependency in term.dependencies:
self._add_to_graph(dependency, parents)
self.add_edge(dependency, term)
self.graph.add_edge(dependency, term)
parents.remove(term)
@@ -94,15 +93,25 @@ class TermGraph(DiGraph):
"""
return self._outputs
def execution_order(self, refcounts):
    """
    Iterate, in topological order, over the terms that must be computed.

    Parameters
    ----------
    refcounts : mapping of term -> int
        Reference counts keyed by term. Only terms whose count is strictly
        greater than zero are kept; everything else is left out of the
        ordering.

    Returns
    -------
    iterator
        Iterator over the topological sort of the subgraph of
        ``self.graph`` restricted to the kept terms.
    """
    still_needed = set()
    for candidate, count in refcounts.items():
        if count > 0:
            still_needed.add(candidate)

    pruned = self.graph.subgraph(still_needed)
    return iter(topological_sort(pruned))
def ordered(self):
"""
Return a topologically-sorted iterator over the terms in `self`.
"""
return iter(self._ordered)
return iter(topological_sort(self.graph))
@lazyval
def loadable_terms(self):
return tuple(term for term in self if isinstance(term, LoadableTerm))
return tuple(
term for term in self.graph if isinstance(term, LoadableTerm)
)
@lazyval
def jpeg(self):
@@ -132,15 +141,28 @@ class TermGraph(DiGraph):
nodes get one extra reference to ensure that they're still in the graph
at the end of execution.
"""
refcounts = self.out_degree()
refcounts = self.graph.out_degree()
for t in self.outputs.values():
refcounts[t] += 1
for t in initial_terms:
self.decref_dependencies(t, refcounts)
self._decref_recursive(t, refcounts, set())
return refcounts
def _decref_recursive(self, term, refcounts, garbage):
    """
    Recursively decrement the refcounts of everything ``term`` depends on.

    Used while building the initial refcounts: each direct dependency of
    ``term`` loses one reference. Any dependency that drops to exactly
    zero is added to ``garbage`` and has its own dependencies decremented
    in turn.

    Parameters
    ----------
    term
        The term whose incoming edges are walked.
    refcounts
        Mutable mapping of term -> reference count, updated in place.
    garbage : set
        Collects every term whose count reached zero, updated in place.

    NOTE(review): this does not check whether ``term``'s dependencies were
    already decremented on its behalf. If a term is both passed in
    directly and reached through recursion from one of its dependents,
    its dependencies are decremented twice. Confirm callers cannot supply
    such a combination.
    """
    # Each in-edge is a (dependency, dependent) pair; only the
    # dependency side is of interest here.
    for edge in self.graph.in_edges([term]):
        dependency = edge[0]
        refcounts[dependency] -= 1
        if refcounts[dependency] != 0:
            # Something else still refers to this dependency.
            continue
        # Nothing refers to this dependency any more, so release it and
        # keep walking up the graph.
        garbage.add(dependency)
        self._decref_recursive(dependency, refcounts, garbage)
def decref_dependencies(self, term, refcounts):
"""
Decrement in-edges for ``term`` after computation.
@@ -159,7 +181,7 @@ class TermGraph(DiGraph):
"""
garbage = set()
# Edges are tuple of (from, to).
for parent, _ in self.in_edges([term]):
for parent, _ in self.graph.in_edges([term]):
refcounts[parent] -= 1
# No one else depends on this term. Remove it from the
# workspace to conserve memory.
@@ -167,6 +189,9 @@ class TermGraph(DiGraph):
garbage.add(parent)
return garbage
def __iter__(self):
    """
    Iterate over this object by delegating to the wrapped ``self.graph``.

    Lets callers write ``for term in graph`` directly on this object
    instead of reaching into ``self.graph``.
    """
    return iter(self.graph)
class ExecutionPlan(TermGraph):
"""
@@ -326,7 +351,7 @@ class ExecutionPlan(TermGraph):
# How much bigger is the array for ``dep`` compared to ``term``?
# How much of that difference did I ask for.
(term, dep): (extra[dep] - extra[term]) - requested_extra_rows
for term in self
for term in self.graph
for dep, requested_extra_rows in term.dependencies.items()
}
@@ -366,14 +391,14 @@ class ExecutionPlan(TermGraph):
"""
return {
term: attrs['extra_rows']
for term, attrs in iteritems(self.node)
for term, attrs in iteritems(self.graph.node)
}
def _ensure_extra_rows(self, term, N):
"""
Ensure that we're going to compute at least N extra rows of `term`.
"""
attrs = self.node[term]
attrs = self.graph.node[term]
attrs['extra_rows'] = max(N, attrs.get('extra_rows', 0))
def mask_and_dates_for_term(self,
+3 -2
View File
@@ -115,13 +115,14 @@ def _render(g, out, format_, include_asset_exists=False):
add_term_node(f, term)
# Write intermediate results.
for term in filter_nodes(include_asset_exists, topological_sort(g)):
for term in filter_nodes(include_asset_exists,
topological_sort(g.graph)):
if term in in_nodes or term in out_nodes:
continue
add_term_node(f, term)
# Write edges
for source, dest in g.edges():
for source, dest in g.graph.edges():
if source is AssetExists() and not include_asset_exists:
continue
add_edge(f, id(source), id(dest))