mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 03:48:58 +08:00
BUG: Fix broken graph visualizations.
This commit is contained in:
@@ -1,10 +1,14 @@
|
||||
"""
|
||||
Tests for zipline.pipeline.Pipeline
|
||||
"""
|
||||
import inspect
|
||||
from unittest import TestCase
|
||||
|
||||
from mock import patch
|
||||
|
||||
from zipline.pipeline import Factor, Filter, Pipeline
|
||||
from zipline.pipeline.data import USEquityPricing
|
||||
from zipline.pipeline.graph import display_graph
|
||||
from zipline.utils.numpy_utils import float64_dtype
|
||||
|
||||
|
||||
@@ -137,3 +141,57 @@ class PipelineTestCase(TestCase):
|
||||
"expected a value of type bool or int for argument 'overwrite'",
|
||||
message,
|
||||
)
|
||||
|
||||
def test_show_graph(self):
|
||||
f = SomeFactor()
|
||||
p = Pipeline(columns={'f': SomeFactor()})
|
||||
|
||||
# The real display_graph call shells out to GraphViz, which isn't a
|
||||
# requirement, so patch it out for testing.
|
||||
|
||||
def mock_display_graph(g, format='svg', include_asset_exists=False):
|
||||
return (g, format, include_asset_exists)
|
||||
|
||||
self.assertEqual(
|
||||
inspect.getargspec(display_graph),
|
||||
inspect.getargspec(mock_display_graph),
|
||||
msg="Mock signature doesn't match signature for display_graph."
|
||||
)
|
||||
|
||||
patch_display_graph = patch(
|
||||
'zipline.pipeline.graph.display_graph',
|
||||
mock_display_graph,
|
||||
)
|
||||
|
||||
with patch_display_graph:
|
||||
graph, format, include_asset_exists = p.show_graph()
|
||||
self.assertIs(graph.outputs['f'], f)
|
||||
# '' is a sentinel used for screen if it's not supplied.
|
||||
self.assertEqual(sorted(graph.outputs.keys()), ['', 'f'])
|
||||
self.assertEqual(format, 'svg')
|
||||
self.assertEqual(include_asset_exists, False)
|
||||
|
||||
with patch_display_graph:
|
||||
graph, format, include_asset_exists = p.show_graph(format='png')
|
||||
self.assertIs(graph.outputs['f'], f)
|
||||
# '' is a sentinel used for screen if it's not supplied.
|
||||
self.assertEqual(sorted(graph.outputs.keys()), ['', 'f'])
|
||||
self.assertEqual(format, 'png')
|
||||
self.assertEqual(include_asset_exists, False)
|
||||
|
||||
with patch_display_graph:
|
||||
graph, format, include_asset_exists = p.show_graph(format='jpeg')
|
||||
self.assertIs(graph.outputs['f'], f)
|
||||
# '' is a sentinel used for screen if it's not supplied.
|
||||
self.assertEqual(sorted(graph.outputs.keys()), ['', 'f'])
|
||||
self.assertEqual(format, 'jpeg')
|
||||
self.assertEqual(include_asset_exists, False)
|
||||
|
||||
expected = (
|
||||
r".*\.show_graph\(\) expected a value in "
|
||||
r"\('svg', 'png', 'jpeg'\) for argument 'format', "
|
||||
r"but got 'fizzbuzz' instead."
|
||||
)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, expected):
|
||||
p.show_graph(format='fizzbuzz')
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
|
||||
from zipline.errors import UnsupportedPipelineOutput
|
||||
from zipline.utils.input_validation import expect_types, optional
|
||||
from zipline.utils.input_validation import (
|
||||
expect_element,
|
||||
expect_types,
|
||||
optional,
|
||||
)
|
||||
|
||||
from .graph import ExecutionPlan, TermGraph
|
||||
from .filters import Filter
|
||||
@@ -189,7 +193,9 @@ class Pipeline(object):
|
||||
default_screen : zipline.pipeline.term.Term
|
||||
Term to use as a screen if self.screen is None.
|
||||
"""
|
||||
return TermGraph(self._prepare_graph_terms())
|
||||
return TermGraph(
|
||||
self._prepare_graph_terms(screen_name, default_screen)
|
||||
)
|
||||
|
||||
def _prepare_graph_terms(self, screen_name, default_screen):
|
||||
"""Helper for to_graph and to_execution_plan."""
|
||||
@@ -200,6 +206,7 @@ class Pipeline(object):
|
||||
columns[screen_name] = screen
|
||||
return columns
|
||||
|
||||
@expect_element(format=('svg', 'png', 'jpeg'))
|
||||
def show_graph(self, format='svg'):
|
||||
"""
|
||||
Render this Pipeline as a DAG.
|
||||
@@ -217,7 +224,9 @@ class Pipeline(object):
|
||||
elif format == 'jpeg':
|
||||
return g.jpeg
|
||||
else:
|
||||
raise ValueError("Unknown graph format %r." % format)
|
||||
# We should never get here because of the expect_element decorator
|
||||
# above.
|
||||
raise AssertionError("Unknown graph format %r." % format)
|
||||
|
||||
@staticmethod
|
||||
def validate_column(column_name, term):
|
||||
|
||||
Reference in New Issue
Block a user