diff --git a/zipline/pipeline/visualize.py b/zipline/pipeline/visualize.py index 44ffe530..c9697edf 100644 --- a/zipline/pipeline/visualize.py +++ b/zipline/pipeline/visualize.py @@ -75,6 +75,12 @@ def roots(g): return set(n for n, d in iteritems(g.in_degree()) if d == 0) +def filter_nodes(include_asset_exists, nodes): + if include_asset_exists: + return nodes + return filter(lambda n: n is not AssetExists(), nodes) + + def _render(g, out, format_, include_asset_exists=False): """ Draw `g` as a graph to `out`, in format `format`. @@ -100,18 +106,16 @@ def _render(g, out, format_, include_asset_exists=False): # Write outputs cluster. with cluster(f, 'Output', labelloc='b', **cluster_attrs): - for term in out_nodes: + for term in filter_nodes(include_asset_exists, out_nodes): add_term_node(f, term) # Write inputs cluster. with cluster(f, 'Input', **cluster_attrs): - for term in in_nodes: - if term is AssetExists() and not include_asset_exists: - continue + for term in filter_nodes(include_asset_exists, in_nodes): add_term_node(f, term) # Write intermediate results. - for term in topological_sort(g): + for term in filter_nodes(include_asset_exists, topological_sort(g)): if term in in_nodes or term in out_nodes: continue add_term_node(f, term)