ENH: Add tools for visualizing FFC graphs.

This commit is contained in:
Scott Sanderson
2015-09-02 10:32:10 -04:00
parent ad54eedeea
commit dad3bbd879
4 changed files with 236 additions and 0 deletions
+3
View File
@@ -74,6 +74,9 @@ class BoundColumn(Term):
dtype=self.dtype.__name__,
)
def short_repr(self):
return self.qualname
class DataSetMeta(type):
"""
+13
View File
@@ -0,0 +1,13 @@
from .classifier import Classifier
from .factor import Factor
from .filter import Filter
from .term import Term
from .graph import TermGraph
__all__ = [
'Classifier',
'Factor',
'Filter',
'Term',
'TermGraph',
]
+6
View File
@@ -305,3 +305,9 @@ class NumericalExpression(Term):
expr=self._expr,
bindings=self.bindings,
)
def short_repr(self):
return "Expression: {expr}".format(
typename=type(self).__name__,
expr=self._expr,
)
+214
View File
@@ -0,0 +1,214 @@
"""
Tools for visualizing dependencies between Terms.
"""
from functools import partial
from contextlib import contextmanager
from logbook import Logger, StderrHandler
from networkx import topological_sort
from six import iteritems, itervalues
import subprocess
from zipline.data.dataset import BoundColumn
from zipline.modelling import Filter, Factor, Classifier, Term
from zipline.modelling.graph import TermGraph
logger = Logger('Visualize')
def delimit(delimiters, content):
"""
Surround `content` with the first and last characters of `delimiters`.
>>> delimit('[]', "foo")
[foo]
>>> delimit('""', "foo")
'"foo"'
"""
if len(delimiters) != 2:
raise ValueError(
"`delimiters` must be of length 2. Got %r" % delimiters
)
return ''.join([delimiters[0], content, delimiters[1]])
quote = partial(delimit, '""')
bracket = partial(delimit, '[]')
def begin_graph(f, name, **attrs):
writeln(f, "strict digraph %s {" % name)
writeln(f, "graph {attrs}".format(attrs=format_attrs(attrs)))
def begin_cluster(f, name, **attrs):
attrs.setdefault("label", quote(name))
writeln(f, "subgraph cluster_%s {" % name)
writeln(f, "graph {attrs}".format(attrs=format_attrs(attrs)))
def end_graph(f):
writeln(f, '}')
@contextmanager
def graph(f, name, **attrs):
begin_graph(f, name, **attrs)
yield
end_graph(f)
@contextmanager
def cluster(f, name, **attrs):
begin_cluster(f, name, **attrs)
yield
end_graph(f)
def roots(g):
"Get nodes from graph G with indegree 0"
return set(n for n, d in iteritems(g.in_degree()) if d == 0)
def write_graph(g, filename, formats=('svg',)):
"""
Write the dependency graph of `terms` as a dot graph.
If `png` (default True), write a .png file using the system `dot` program.
If `pdf` (default False), write a .pdf file using the system `dot` program.
"""
dotfile = filename + '.dot'
graph_attrs = {'rankdir': 'TB', 'splines': 'ortho'}
cluster_attrs = {'style': 'filled', 'color': 'lightgoldenrod1'}
in_nodes = list(node for node in g if node.atomic)
out_nodes = list(g.outputs.values())
with open(dotfile, 'w') as f:
with graph(f, "G", **graph_attrs):
# Write outputs cluster.
with cluster(f, 'Output', labelloc='b', **cluster_attrs):
for term in out_nodes:
add_term_node(f, term)
# Write inputs cluster.
with cluster(f, 'Input', **cluster_attrs):
for term in in_nodes:
add_term_node(f, term)
# Write intermediate results.
for term in topological_sort(g):
if term in in_nodes or term in out_nodes:
continue
add_term_node(f, term)
# Write edges
for source, dest in g.edges():
add_edge(f, id(source), id(dest))
outs = []
for format_ in formats:
out = '.'.join([filename, format_])
logger.info('Writing "%s"' % out)
subprocess.call(['dot', '-T', format_, dotfile, '-o', out])
outs.append(out)
return outs
def show_graph(g):
"""
Display a TermGraph interactively with IPython
"""
try:
from IPython.display import SVG
except ImportError:
raise Exception("IPython is not installed. Can't show term graph.")
result = write_graph(g, 'temp', ('svg',))[0]
return SVG(filename=result)
def writeln(f, s):
f.write(s + '\n')
def fmt(obj):
if isinstance(obj, Term):
if hasattr(obj, 'short_repr'):
r = obj.short_repr()
else:
r = type(obj).__name__
else:
r = obj
return '"%s"' % r
def add_term_node(f, term):
declare_node(f, id(term), attrs_for_node(term))
def declare_node(f, name, attributes):
writeln(f, "{0} {1};".format(name, format_attrs(attributes)))
def add_edge(f, source, dest):
writeln(f, "{0} -> {1};".format(source, dest))
def attrs_for_node(term, **overrides):
attrs = {
'shape': 'box',
'colorscheme': 'pastel19',
'style': 'filled',
'label': fmt(term),
}
if isinstance(term, BoundColumn):
attrs['fillcolor'] = '1'
if isinstance(term, Factor):
attrs['fillcolor'] = '2'
elif isinstance(term, Filter):
attrs['fillcolor'] = '3'
elif isinstance(term, Classifier):
attrs['fillcolor'] = '4'
attrs.update(**overrides or {})
return attrs
def format_attrs(attrs):
"""
Format key, value pairs from attrs into graphviz attrs format
Example
-------
>>> format_attrs({'key1': 'value1', 'key2': 'value2'})
'[key1=value1, key2=value2]'
"""
if not attrs:
return ''
entries = ['='.join((key, value)) for key, value in iteritems(attrs)]
return '[' + ', '.join(entries) + ']'
if __name__ == '__main__':
from zipline.modelling.factor.technical import VWAP, MaxDrawdown, RSI
from zipline.data.equities import USEquityPricing
with StderrHandler():
vwap = VWAP(window_length=5)
dd_rank = MaxDrawdown([USEquityPricing.close], window_length=5).rank()
rsi_rank = (-RSI()).rank()
score = ((dd_rank + rsi_rank) / 2)
write_graph(
TermGraph(
{
'vwap_pct': vwap.percentile_between(0, 20),
'score': score,
'score_gt': score > 10.0,
},
),
'new',
('png',),
)