BUG: Fix broken graph visualizations.

This commit is contained in:
Scott Sanderson
2016-08-18 11:07:17 -04:00
parent c5a3dae267
commit bdc72ec4c0
2 changed files with 70 additions and 3 deletions
+58
View File
@@ -1,10 +1,14 @@
"""
Tests for zipline.pipeline.Pipeline
"""
import inspect
from unittest import TestCase
from mock import patch
from zipline.pipeline import Factor, Filter, Pipeline
from zipline.pipeline.data import USEquityPricing
from zipline.pipeline.graph import display_graph
from zipline.utils.numpy_utils import float64_dtype
@@ -137,3 +141,57 @@ class PipelineTestCase(TestCase):
"expected a value of type bool or int for argument 'overwrite'",
message,
)
def test_show_graph(self):
f = SomeFactor()
p = Pipeline(columns={'f': SomeFactor()})
# The real display_graph call shells out to GraphViz, which isn't a
# requirement, so patch it out for testing.
def mock_display_graph(g, format='svg', include_asset_exists=False):
return (g, format, include_asset_exists)
self.assertEqual(
inspect.getargspec(display_graph),
inspect.getargspec(mock_display_graph),
msg="Mock signature doesn't match signature for display_graph."
)
patch_display_graph = patch(
'zipline.pipeline.graph.display_graph',
mock_display_graph,
)
with patch_display_graph:
graph, format, include_asset_exists = p.show_graph()
self.assertIs(graph.outputs['f'], f)
# '' is a sentinel used for screen if it's not supplied.
self.assertEqual(sorted(graph.outputs.keys()), ['', 'f'])
self.assertEqual(format, 'svg')
self.assertEqual(include_asset_exists, False)
with patch_display_graph:
graph, format, include_asset_exists = p.show_graph(format='png')
self.assertIs(graph.outputs['f'], f)
# '' is a sentinel used for screen if it's not supplied.
self.assertEqual(sorted(graph.outputs.keys()), ['', 'f'])
self.assertEqual(format, 'png')
self.assertEqual(include_asset_exists, False)
with patch_display_graph:
graph, format, include_asset_exists = p.show_graph(format='jpeg')
self.assertIs(graph.outputs['f'], f)
# '' is a sentinel used for screen if it's not supplied.
self.assertEqual(sorted(graph.outputs.keys()), ['', 'f'])
self.assertEqual(format, 'jpeg')
self.assertEqual(include_asset_exists, False)
expected = (
r".*\.show_graph\(\) expected a value in "
r"\('svg', 'png', 'jpeg'\) for argument 'format', "
r"but got 'fizzbuzz' instead."
)
with self.assertRaisesRegexp(ValueError, expected):
p.show_graph(format='fizzbuzz')
+12 -3
View File
@@ -1,6 +1,10 @@
from zipline.errors import UnsupportedPipelineOutput
from zipline.utils.input_validation import expect_types, optional
from zipline.utils.input_validation import (
expect_element,
expect_types,
optional,
)
from .graph import ExecutionPlan, TermGraph
from .filters import Filter
@@ -189,7 +193,9 @@ class Pipeline(object):
default_screen : zipline.pipeline.term.Term
Term to use as a screen if self.screen is None.
"""
return TermGraph(self._prepare_graph_terms())
return TermGraph(
self._prepare_graph_terms(screen_name, default_screen)
)
def _prepare_graph_terms(self, screen_name, default_screen):
"""Helper for to_graph and to_execution_plan."""
@@ -200,6 +206,7 @@ class Pipeline(object):
columns[screen_name] = screen
return columns
@expect_element(format=('svg', 'png', 'jpeg'))
def show_graph(self, format='svg'):
"""
Render this Pipeline as a DAG.
@@ -217,7 +224,9 @@ class Pipeline(object):
elif format == 'jpeg':
return g.jpeg
else:
raise ValueError("Unknown graph format %r." % format)
# We should never get here because of the expect_element decorator
# above.
raise AssertionError("Unknown graph format %r." % format)
@staticmethod
def validate_column(column_name, term):