mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 14:20:48 +08:00
ENH: prune the graph based on the initial workspace
This commit is contained in:
@@ -66,6 +66,7 @@ from zipline.pipeline.term import InputDates
|
||||
from zipline.testing import (
|
||||
AssetID,
|
||||
AssetIDPlusDay,
|
||||
ExplodingObject,
|
||||
check_arrays,
|
||||
make_alternating_boolean_array,
|
||||
make_cascading_boolean_array,
|
||||
@@ -1320,18 +1321,11 @@ class StringColumnTestCase(WithSeededRandomPipelineEngine,
|
||||
|
||||
|
||||
class PopulateInitialWorkspaceTestCase(WithConstantInputs, ZiplineTestCase):
|
||||
def make_engine(self, populate_initial_workspace):
|
||||
return SimplePipelineEngine(
|
||||
lambda column: self.loader,
|
||||
self.dates,
|
||||
self.asset_finder,
|
||||
populate_initial_workspace=populate_initial_workspace,
|
||||
)
|
||||
|
||||
def test_populate_default_workspace(self):
|
||||
column = USEquityPricing.low
|
||||
base_term = column.latest
|
||||
term = base_term + 1
|
||||
term = (base_term + 1).alias('term')
|
||||
composed_term = term + 1
|
||||
column_value = self.constants[column]
|
||||
precomputed_value = -column_value
|
||||
|
||||
@@ -1343,25 +1337,39 @@ class PopulateInitialWorkspaceTestCase(WithConstantInputs, ZiplineTestCase):
|
||||
return assoc(
|
||||
initial_workspace,
|
||||
term,
|
||||
full((len(dates), len(assets)), precomputed_value),
|
||||
full(
|
||||
(len(dates), len(assets)),
|
||||
precomputed_value,
|
||||
dtype=float64,
|
||||
),
|
||||
)
|
||||
|
||||
# I resisted the urge to use ``make_engine`` as a decorator here
|
||||
# because Scott would have yelled at me.
|
||||
engine = self.make_engine(populate_initial_workspace)
|
||||
def dispatcher(column):
|
||||
if column is base_term:
|
||||
# the base_term should never be loaded, its initial refcount
|
||||
# should be zero
|
||||
return ExplodingObject()
|
||||
return self.loader
|
||||
|
||||
engine = SimplePipelineEngine(
|
||||
dispatcher,
|
||||
self.dates,
|
||||
self.asset_finder,
|
||||
populate_initial_workspace=populate_initial_workspace,
|
||||
)
|
||||
|
||||
results = engine.run_pipeline(
|
||||
Pipeline({
|
||||
'term-in-initial-workspace': term,
|
||||
'term-not-in-initial-workspace': base_term,
|
||||
'term': term,
|
||||
'composed_term': composed_term,
|
||||
}),
|
||||
self.dates[0],
|
||||
self.dates[-1],
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
(results['term-in-initial-workspace'] == precomputed_value).all(),
|
||||
(results['term'] == precomputed_value).all(),
|
||||
)
|
||||
self.assertTrue(
|
||||
(results['term-not-in-initial-workspace'] == column_value).all(),
|
||||
(results['composed_term'] == (precomputed_value + 1)).all(),
|
||||
)
|
||||
|
||||
@@ -195,8 +195,14 @@ class DependencyResolutionTestCase(WithTradingSessions, ZiplineTestCase):
|
||||
self.assertIn(SomeDataSet.bar, resolution_order)
|
||||
self.assertIn(SomeFactor(), resolution_order)
|
||||
|
||||
self.assertEqual(graph.node[SomeDataSet.foo]['extra_rows'], 4)
|
||||
self.assertEqual(graph.node[SomeDataSet.bar]['extra_rows'], 4)
|
||||
self.assertEqual(
|
||||
graph.graph.node[SomeDataSet.foo]['extra_rows'],
|
||||
4,
|
||||
)
|
||||
self.assertEqual(
|
||||
graph.graph.node[SomeDataSet.bar]['extra_rows'],
|
||||
4,
|
||||
)
|
||||
|
||||
for foobar in gen_equivalent_factors():
|
||||
check_output(self.make_execution_plan(to_dict([foobar])))
|
||||
|
||||
@@ -362,7 +362,7 @@ class SimplePipelineEngine(object):
|
||||
|
||||
refcounts = graph.initial_refcounts(workspace)
|
||||
|
||||
for term in graph.ordered():
|
||||
for term in graph.execution_order(refcounts):
|
||||
# `term` may have been supplied in `initial_workspace`, and in the
|
||||
# future we may pre-compute loadable terms coming from the same
|
||||
# dataset. In either case, we will already have an entry for this
|
||||
|
||||
+41
-16
@@ -16,7 +16,7 @@ class CyclicDependency(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class TermGraph(DiGraph):
|
||||
class TermGraph(object):
|
||||
"""
|
||||
An abstract representation of Pipeline Term dependencies.
|
||||
|
||||
@@ -44,7 +44,7 @@ class TermGraph(DiGraph):
|
||||
ExecutionPlan
|
||||
"""
|
||||
def __init__(self, terms):
|
||||
super(TermGraph, self).__init__()
|
||||
self.graph = DiGraph()
|
||||
|
||||
self._frozen = False
|
||||
parents = set()
|
||||
@@ -54,7 +54,6 @@ class TermGraph(DiGraph):
|
||||
assert not parents
|
||||
|
||||
self._outputs = terms
|
||||
self._ordered = topological_sort(self)
|
||||
|
||||
# Mark that no more terms should be added to the graph.
|
||||
self._frozen = True
|
||||
@@ -79,11 +78,11 @@ class TermGraph(DiGraph):
|
||||
|
||||
parents.add(term)
|
||||
|
||||
self.add_node(term)
|
||||
self.graph.add_node(term)
|
||||
|
||||
for dependency in term.dependencies:
|
||||
self._add_to_graph(dependency, parents)
|
||||
self.add_edge(dependency, term)
|
||||
self.graph.add_edge(dependency, term)
|
||||
|
||||
parents.remove(term)
|
||||
|
||||
@@ -94,15 +93,25 @@ class TermGraph(DiGraph):
|
||||
"""
|
||||
return self._outputs
|
||||
|
||||
def execution_order(self, refcounts):
|
||||
"""
|
||||
Return a topologically-sorted iterator over the terms in ``self`` which
|
||||
need to be computed.
|
||||
"""
|
||||
return iter(topological_sort(
|
||||
self.graph.subgraph(
|
||||
{term for term, refcount in refcounts.items() if refcount > 0},
|
||||
),
|
||||
))
|
||||
|
||||
def ordered(self):
|
||||
"""
|
||||
Return a topologically-sorted iterator over the terms in `self`.
|
||||
"""
|
||||
return iter(self._ordered)
|
||||
return iter(topological_sort(self.graph))
|
||||
|
||||
@lazyval
|
||||
def loadable_terms(self):
|
||||
return tuple(term for term in self if isinstance(term, LoadableTerm))
|
||||
return tuple(
|
||||
term for term in self.graph if isinstance(term, LoadableTerm)
|
||||
)
|
||||
|
||||
@lazyval
|
||||
def jpeg(self):
|
||||
@@ -132,15 +141,28 @@ class TermGraph(DiGraph):
|
||||
nodes get one extra reference to ensure that they're still in the graph
|
||||
at the end of execution.
|
||||
"""
|
||||
refcounts = self.out_degree()
|
||||
refcounts = self.graph.out_degree()
|
||||
for t in self.outputs.values():
|
||||
refcounts[t] += 1
|
||||
|
||||
for t in initial_terms:
|
||||
self.decref_dependencies(t, refcounts)
|
||||
self._decref_recursive(t, refcounts, set())
|
||||
|
||||
return refcounts
|
||||
|
||||
def _decref_recursive(self, term, refcounts, garbage):
|
||||
"""
|
||||
Decrement terms recursively to build the initial workspace.
|
||||
"""
|
||||
# Edges are tuple of (from, to).
|
||||
for parent, _ in self.graph.in_edges([term]):
|
||||
refcounts[parent] -= 1
|
||||
# No one else depends on this term. Remove it from the
|
||||
# workspace to conserve memory.
|
||||
if refcounts[parent] == 0:
|
||||
garbage.add(parent)
|
||||
self._decref_recursive(parent, refcounts, garbage)
|
||||
|
||||
def decref_dependencies(self, term, refcounts):
|
||||
"""
|
||||
Decrement in-edges for ``term`` after computation.
|
||||
@@ -159,7 +181,7 @@ class TermGraph(DiGraph):
|
||||
"""
|
||||
garbage = set()
|
||||
# Edges are tuple of (from, to).
|
||||
for parent, _ in self.in_edges([term]):
|
||||
for parent, _ in self.graph.in_edges([term]):
|
||||
refcounts[parent] -= 1
|
||||
# No one else depends on this term. Remove it from the
|
||||
# workspace to conserve memory.
|
||||
@@ -167,6 +189,9 @@ class TermGraph(DiGraph):
|
||||
garbage.add(parent)
|
||||
return garbage
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.graph)
|
||||
|
||||
|
||||
class ExecutionPlan(TermGraph):
|
||||
"""
|
||||
@@ -326,7 +351,7 @@ class ExecutionPlan(TermGraph):
|
||||
# How much bigger is the array for ``dep`` compared to ``term``?
|
||||
# How much of that difference did I ask for.
|
||||
(term, dep): (extra[dep] - extra[term]) - requested_extra_rows
|
||||
for term in self
|
||||
for term in self.graph
|
||||
for dep, requested_extra_rows in term.dependencies.items()
|
||||
}
|
||||
|
||||
@@ -366,14 +391,14 @@ class ExecutionPlan(TermGraph):
|
||||
"""
|
||||
return {
|
||||
term: attrs['extra_rows']
|
||||
for term, attrs in iteritems(self.node)
|
||||
for term, attrs in iteritems(self.graph.node)
|
||||
}
|
||||
|
||||
def _ensure_extra_rows(self, term, N):
|
||||
"""
|
||||
Ensure that we're going to compute at least N extra rows of `term`.
|
||||
"""
|
||||
attrs = self.node[term]
|
||||
attrs = self.graph.node[term]
|
||||
attrs['extra_rows'] = max(N, attrs.get('extra_rows', 0))
|
||||
|
||||
def mask_and_dates_for_term(self,
|
||||
|
||||
@@ -115,13 +115,14 @@ def _render(g, out, format_, include_asset_exists=False):
|
||||
add_term_node(f, term)
|
||||
|
||||
# Write intermediate results.
|
||||
for term in filter_nodes(include_asset_exists, topological_sort(g)):
|
||||
for term in filter_nodes(include_asset_exists,
|
||||
topological_sort(g.graph)):
|
||||
if term in in_nodes or term in out_nodes:
|
||||
continue
|
||||
add_term_node(f, term)
|
||||
|
||||
# Write edges
|
||||
for source, dest in g.edges():
|
||||
for source, dest in g.graph.edges():
|
||||
if source is AssetExists() and not include_asset_exists:
|
||||
continue
|
||||
add_edge(f, id(source), id(dest))
|
||||
|
||||
Reference in New Issue
Block a user