Files
catalyst/zipline/modelling/visualize.py
T
Scott Sanderson 670706158a ENH: Don't write files when rendering TermGraphs.
Overhauls `zipline.modelling.visualize` to use in-memory buffers when
shelling out to `dot` and friends.

Also adds `svg`, `png`, and `jpeg` attributes to `TermGraph`, and adds a
`_repr_png_` so that `TermGraph` renders as a PNG by default.
2015-10-01 18:03:53 -04:00

227 lines
5.7 KiB
Python

"""
Tools for visualizing dependencies between Terms.
"""
from __future__ import unicode_literals
from contextlib import contextmanager
import errno
from functools import partial
from io import BytesIO
from subprocess import Popen, PIPE
from networkx import topological_sort
from six import iteritems
from zipline.data.dataset import BoundColumn
from zipline.modelling import Filter, Factor, Classifier, Term
from zipline.modelling.term import AssetExists
class NoIPython(Exception):
    """Raised by `display_graph` when IPython cannot be imported."""
def delimit(delimiters, content):
    """
    Surround `content` with the first and last characters of `delimiters`.

    Parameters
    ----------
    delimiters : sequence of length 2
        Opening and closing delimiter, e.g. ``'[]'``.
    content : str
        Text to wrap.

    Returns
    -------
    str
        ``delimiters[0] + content + delimiters[1]``.

    Raises
    ------
    ValueError
        If `delimiters` does not have exactly two elements.

    >>> delimit('[]', "foo")
    '[foo]'
    >>> delimit('""', "foo")
    '"foo"'
    """
    if len(delimiters) != 2:
        # Wrap in a 1-tuple so that a tuple-valued `delimiters` is formatted
        # as a single value instead of being unpacked as %-format arguments
        # (which would raise TypeError instead of the intended ValueError).
        raise ValueError(
            "`delimiters` must be of length 2. Got %r" % (delimiters,)
        )
    return ''.join([delimiters[0], content, delimiters[1]])
# Wrap a string in double quotes: foo -> "foo".
quote = partial(delimit, '""')
# Wrap a string in square brackets: foo -> [foo].
bracket = partial(delimit, '[]')
def begin_graph(f, name, **attrs):
    """
    Write the opening lines of a strict digraph called `name` to `f`.

    The first line opens the graph body; the second is a ``graph``
    statement carrying `attrs` formatted by `format_attrs`.
    """
    opening_line = "strict digraph %s {" % name
    writeln(f, opening_line)
    attrs_line = "graph {}".format(format_attrs(attrs))
    writeln(f, attrs_line)
def begin_cluster(f, name, **attrs):
    """
    Write the opening lines of a ``cluster_<name>`` subgraph to `f`.

    Unless the caller supplies a ``label`` attribute, the cluster is
    labelled with the double-quoted `name`.
    """
    default_label = quote(name)
    if "label" not in attrs:
        attrs["label"] = default_label
    opening_line = "subgraph cluster_%s {" % name
    writeln(f, opening_line)
    attrs_line = "graph {}".format(format_attrs(attrs))
    writeln(f, attrs_line)
def end_graph(f):
    """Write the closing brace of a graph or cluster body to `f`."""
    closing_brace = '}'
    writeln(f, closing_brace)
@contextmanager
def graph(f, name, **attrs):
    """
    Context manager that brackets its body with a digraph's opening and
    closing lines written to `f`.
    """
    begin_graph(f, name, **attrs)
    yield
    # Reached only when the with-body finishes without raising.
    end_graph(f)
@contextmanager
def cluster(f, name, **attrs):
    """
    Context manager that brackets its body with a cluster subgraph's
    opening and closing lines written to `f`.
    """
    begin_cluster(f, name, **attrs)
    yield
    # Reached only when the with-body finishes without raising.
    end_graph(f)
def roots(g):
    """
    Get nodes from graph `g` with indegree 0.

    Parameters
    ----------
    g : graph-like
        Object whose ``in_degree()`` returns either a mapping of
        node -> in-degree or an iterable of ``(node, in_degree)`` pairs.

    Returns
    -------
    set
        The nodes that have no incoming edges.
    """
    # networkx 1.x returns a dict here, while networkx 2.x returns a view
    # that iterates as (node, degree) pairs; dict() normalizes both shapes.
    in_degrees = dict(g.in_degree())
    return {node for node, degree in in_degrees.items() if degree == 0}
def _render(g, out, format_, include_asset_exists=False):
    """
    Draw `g` as a graph to `out`, in format `format_`.

    The graph is serialized as DOT source into an in-memory buffer and
    piped through the `dot` program; whatever `dot` writes to stdout is
    written to `out`.

    Parameters
    ----------
    g : zipline.modelling.graph.TermGraph
        Graph to render.
    out : file-like object
        Receives the bytes produced by `dot`.
    format_ : str
        Output format passed to ``dot -T`` (e.g. 'png', 'svg', 'jpeg').
    include_asset_exists : bool
        Whether to draw `AssetExists()` nodes and the edges leaving them.
        When False (the default) they are left out.

    Raises
    ------
    RuntimeError
        If `dot` cannot be found, or if it writes anything to stderr.
    """
    def hidden(term):
        # True for the AssetExists() term when it should not be drawn.
        return term is AssetExists() and not include_asset_exists

    graph_attrs = {'rankdir': 'TB', 'splines': 'ortho'}
    cluster_attrs = {'style': 'filled', 'color': 'lightgoldenrod1'}

    in_nodes = [node for node in g if node.atomic]
    out_nodes = list(g.outputs.values())

    dot_source = BytesIO()
    with graph(dot_source, "G", **graph_attrs):

        # Outputs cluster.
        with cluster(dot_source, 'Output', labelloc='b', **cluster_attrs):
            for term in out_nodes:
                add_term_node(dot_source, term)

        # Inputs cluster.
        with cluster(dot_source, 'Input', **cluster_attrs):
            for term in in_nodes:
                if not hidden(term):
                    add_term_node(dot_source, term)

        # Intermediate results: everything that is neither input nor output.
        for term in topological_sort(g):
            if term in in_nodes or term in out_nodes:
                continue
            add_term_node(dot_source, term)

        # Edges, keyed by the same id() values used to declare the nodes.
        for source, dest in g.edges():
            if not hidden(source):
                add_edge(dot_source, id(source), id(dest))

    cmd = ['dot', '-T', format_]
    try:
        proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE)
    except OSError as e:
        if e.errno != errno.ENOENT:
            raise
        # ENOENT means the `dot` executable itself could not be found.
        raise RuntimeError(
            "Couldn't find `dot` graph layout program. "
            "Make sure Graphviz is installed and `dot` is on your path."
        )

    proc_stdout, proc_stderr = proc.communicate(dot_source.getvalue())
    if proc_stderr:
        raise RuntimeError(
            "Error(s) while rendering graph: %s" % proc_stderr.decode('utf-8')
        )

    out.write(proc_stdout)
def display_graph(g, format='svg', include_asset_exists=False):
    """
    Display a TermGraph interactively from within IPython.

    Parameters
    ----------
    g : TermGraph
        Graph to display.
    format : str {'svg', 'png', 'jpeg'}
        Image format to render.
    include_asset_exists : bool
        Forwarded to `_render`.

    Returns
    -------
    An ``IPython.display.SVG`` for 'svg', otherwise an embedded
    ``IPython.display.Image`` holding the rendered graph.

    Raises
    ------
    ValueError
        If `format` is not one of the supported formats.
    NoIPython
        If IPython cannot be imported.
    """
    # Validate up front: previously an unsupported format left `display_cls`
    # unbound and surfaced as a confusing UnboundLocalError.
    if format not in ('svg', 'jpeg', 'png'):
        raise ValueError(
            "Unsupported format %r. Expected one of 'svg', 'jpeg', 'png'."
            % (format,)
        )

    try:
        import IPython.display as display
    except ImportError:
        raise NoIPython("IPython is not installed. Can't display graph.")

    if format == 'svg':
        display_cls = display.SVG
    else:
        display_cls = partial(display.Image, format=format, embed=True)

    out = BytesIO()
    _render(g, out, format, include_asset_exists=include_asset_exists)
    return display_cls(data=out.getvalue())
def writeln(f, s):
    """Write `s` plus a trailing newline to `f` as UTF-8 encoded bytes."""
    line = s + '\n'
    f.write(line.encode('utf-8'))
def fmt(obj):
    """
    Return a double-quoted label for `obj`.

    A Term is labelled with its ``short_repr()`` when it has one and with
    its class name otherwise; any other object is labelled with itself.
    """
    if not isinstance(obj, Term):
        text = obj
    elif hasattr(obj, 'short_repr'):
        text = obj.short_repr()
    else:
        text = type(obj).__name__
    return '"%s"' % text
def add_term_node(f, term):
    """Declare a node for `term` in `f`, named by ``id(term)``."""
    node_name = id(term)
    declare_node(f, node_name, attrs_for_node(term))
def declare_node(f, name, attributes):
    """Write a node statement for `name` with `attributes` to `f`."""
    statement = "{0} {1};".format(name, format_attrs(attributes))
    writeln(f, statement)
def add_edge(f, source, dest):
    """Write a directed edge statement from `source` to `dest` to `f`."""
    statement = "{0} -> {1};".format(source, dest)
    writeln(f, statement)
def attrs_for_node(term, **overrides):
    """
    Build the attribute dict used to draw `term` as a node.

    Every node is a filled box labelled with ``fmt(term)``. The fill color
    index depends on the term's type: Factor, Filter and Classifier take
    '2', '3' and '4' respectively; a BoundColumn that is none of those
    takes '1'. Keyword `overrides` are applied last.
    """
    node_attrs = {
        'shape': 'box',
        'colorscheme': 'pastel19',
        'style': 'filled',
        'label': fmt(term),
    }

    # BoundColumn is the fallback color; a computed-term type wins over it.
    fill = '1' if isinstance(term, BoundColumn) else None
    for term_type, color in ((Factor, '2'), (Filter, '3'), (Classifier, '4')):
        if isinstance(term, term_type):
            fill = color
            break
    if fill is not None:
        node_attrs['fillcolor'] = fill

    node_attrs.update(overrides)
    return node_attrs
def format_attrs(attrs):
    """
    Format key, value pairs from attrs into graphviz attrs format.

    Parameters
    ----------
    attrs : dict
        Mapping of attribute name to value. Values are converted with
        ``str.format``, so non-string values such as numbers are accepted.

    Returns
    -------
    str
        ``'[k1=v1, k2=v2]'`` in the mapping's iteration order, or ``''``
        when `attrs` is empty.

    Example
    -------
    >>> format_attrs({'key1': 'value1', 'key2': 'value2'})
    '[key1=value1, key2=value2]'
    """
    if not attrs:
        return ''
    # Format rather than '='.join(...) so that non-string values (e.g. a
    # numeric penwidth) do not raise TypeError.
    entries = ['{0}={1}'.format(key, value) for key, value in attrs.items()]
    return '[' + ', '.join(entries) + ']'