diff --git a/tests/pipeline/test_engine.py b/tests/pipeline/test_engine.py index 9054c57f..65d200bd 100644 --- a/tests/pipeline/test_engine.py +++ b/tests/pipeline/test_engine.py @@ -117,7 +117,7 @@ class ConstantInputTestCase(TestCase): loader = self.loader engine = SimplePipelineEngine(loader, self.dates, self.asset_finder) - p = Pipeline('test') + p = Pipeline() msg = "start_date must be before end_date .*" with self.assertRaisesRegexp(ValueError, msg): @@ -135,7 +135,7 @@ class ConstantInputTestCase(TestCase): factor = AssetID() for asset in assets: - p = Pipeline('test', columns={'f': factor}, screen=factor <= asset) + p = Pipeline(columns={'f': factor}, screen=factor <= asset) result = engine.run_pipeline(p, dates[0], dates[-1]) expected_sids = assets[assets <= asset] @@ -161,9 +161,8 @@ class ConstantInputTestCase(TestCase): # Since every asset will pass the screen, these should be equivalent. pipelines = [ - Pipeline('test', columns={'f': factor}), + Pipeline(columns={'f': factor}), Pipeline( - 'test', columns={'f': factor}, screen=factor.eq(expected_result), ), @@ -198,7 +197,6 @@ class ConstantInputTestCase(TestCase): ) pipeline = Pipeline( - 'test', columns={ 'short': short_factor, 'long': long_factor, @@ -242,7 +240,6 @@ class ConstantInputTestCase(TestCase): results = engine.run_pipeline( Pipeline( - 'test', columns={ 'high_low': high_minus_low, 'open_close': open_minus_close, @@ -373,7 +370,6 @@ class FrameInputTestCase(TestCase): for start, stop in bounds: results = engine.run_pipeline( Pipeline( - 'test', columns={'low': low_mavg, 'high': high_mavg} ), dates[start], @@ -488,7 +484,7 @@ class SyntheticBcolzTestCase(TestCase): ) results = engine.run_pipeline( - Pipeline('test', columns={'sma': SMA}), + Pipeline(columns={'sma': SMA}), dates_to_test[0], dates_to_test[-1], ) @@ -540,7 +536,7 @@ class SyntheticBcolzTestCase(TestCase): ) results = engine.run_pipeline( - Pipeline('test', columns={'drawdown': drawdown}), + Pipeline(columns={'drawdown': drawdown}), dates_to_test[0], dates_to_test[-1], ) @@ -594,7 +590,6 @@ class MultiColumnLoaderTestCase(TestCase): result = engine.run_pipeline( Pipeline( - 'test', columns={ 'sumdiff': sumdiff, 'open': open_.latest, diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index a2b752be..a57a312a 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -30,21 +30,20 @@ class SomeOtherFilter(Filter): class PipelineTestCase(TestCase): def test_construction(self): - p0 = Pipeline('arglebargle') - self.assertEqual(p0.name, 'arglebargle') + p0 = Pipeline() self.assertEqual(p0.columns, {}) self.assertIs(p0.screen, None) columns = {'f': SomeFactor()} - p1 = Pipeline('test', columns=columns) + p1 = Pipeline(columns=columns) self.assertEqual(p1.columns, columns) screen = SomeFilter() - p2 = Pipeline('test', screen=screen) + p2 = Pipeline(screen=screen) self.assertEqual(p2.columns, {}) self.assertEqual(p2.screen, screen) - p3 = Pipeline('test', columns=columns, screen=screen) + p3 = Pipeline(columns=columns, screen=screen) self.assertEqual(p3.columns, columns) self.assertEqual(p3.screen, screen) @@ -53,21 +52,18 @@ class PipelineTestCase(TestCase): with self.assertRaises(TypeError): Pipeline(1) - with self.assertRaises(TypeError): - Pipeline('test', 1) - - Pipeline('test', {}) + Pipeline({}) with self.assertRaises(TypeError): - Pipeline('test', {}, 1) + Pipeline({}, 1) with self.assertRaises(TypeError): - Pipeline('test', {}, SomeFactor()) + Pipeline({}, SomeFactor()) - Pipeline('test', {}, SomeFactor() > 5) + Pipeline({}, SomeFactor() > 5) def test_add(self): - p = Pipeline('test') + p = Pipeline() f = SomeFactor() p.add(f, 'f') @@ -80,7 +76,7 @@ class PipelineTestCase(TestCase): p.add(f, 1) def test_overwrite(self): - p = Pipeline('test') + p = Pipeline() f = SomeFactor() other_f = SomeOtherFactor() @@ -97,7 +93,7 @@ class PipelineTestCase(TestCase): def test_remove(self): f = SomeFactor() - p = Pipeline('test', columns={'f': f}) + p = Pipeline(columns={'f': f}) with self.assertRaises(KeyError) as e: p.remove('not_a_real_name') @@ -112,7 +108,7 @@ class PipelineTestCase(TestCase): def test_set_screen(self): f, g = SomeFilter(), SomeOtherFilter() - p = Pipeline('test') + p = Pipeline() self.assertEqual(p.screen, None) p.set_screen(f) diff --git a/tests/pipeline/test_pipeline_algo.py b/tests/pipeline/test_pipeline_algo.py index 7de4b9e7..9742f40a 100644 --- a/tests/pipeline/test_pipeline_algo.py +++ b/tests/pipeline/test_pipeline_algo.py @@ -173,7 +173,7 @@ class ClosesOnly(TestCase): pass def late_attach(context, data): - attach_pipeline(Pipeline('test')) + attach_pipeline(Pipeline(), 'test') raise AssertionError("Shouldn't make it past attach_pipeline!") algo = TradingAlgorithm( @@ -211,7 +211,7 @@ class ClosesOnly(TestCase): Assert that calling pipeline_output after initialize raises correctly. """ def initialize(context): - attach_pipeline(Pipeline('test')) + attach_pipeline(Pipeline(), 'test') pipeline_output('test') raise AssertionError("Shouldn't make it past pipeline_output()") @@ -240,7 +240,7 @@ class ClosesOnly(TestCase): Assert that calling add_pipeline after initialize raises appropriately. """ def initialize(context): - attach_pipeline(Pipeline('test')) + attach_pipeline(Pipeline(), 'test') def handle_data(context, data): raise AssertionError("Shouldn't make it past before_trading_start") @@ -269,11 +269,9 @@ class ClosesOnly(TestCase): correctly-adjusted close price values. """ def initialize(context): - p = Pipeline('test') + p = attach_pipeline(Pipeline(), 'test') p.add(USEquityPricing.close.latest, 'close') - attach_pipeline(p) - def handle_data(context, data): results = pipeline_output('test') date = get_datetime().normalize() @@ -477,7 +475,7 @@ class PipelineAlgorithmTestCase(TestCase): return "vwap_%d" % length def initialize(context): - pipeline = Pipeline('test') + pipeline = Pipeline() context.vwaps = [] for length in vwaps: name = vwap_key(length) @@ -490,7 +488,7 @@ class PipelineAlgorithmTestCase(TestCase): if set_screen: pipeline.set_screen(filter_) - attach_pipeline(pipeline) + attach_pipeline(pipeline, 'test') def handle_data(context, data): today = get_datetime() diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 3728874f..05fda7ab 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -231,7 +231,7 @@ class TradingAlgorithm(object): # Initialize Pipeline API data. self.init_engine(kwargs.pop('pipeline_loader', None)) - self._pipelines = [] + self._pipelines = {} # Create an always-expired cache so that we compute the first time data # is requested. self._pipeline_cache = CachedObject(None, pd.Timestamp(0, tz='UTC')) @@ -1337,13 +1337,17 @@ class TradingAlgorithm(object): ############## @api_method @require_not_initialized(AttachPipelineAfterInitialize()) - def attach_pipeline(self, pipeline): + def attach_pipeline(self, pipeline, name): """ Register a pipeline to be computed at the start of each day. """ if self._pipelines: raise NotImplementedError("Multiple pipelines are not supported.") - self._pipelines.append(pipeline) + self._pipelines[name] = pipeline + + # Return the pipeline to allow expressions like + # p = attach_pipeline(Pipeline(), 'name') + return pipeline @api_method @require_initialized(PipelineOutputDuringInitialize()) @@ -1353,7 +1357,7 @@ class TradingAlgorithm(object): Parameters ---------- - name : str or None + name : str Name of the pipeline for which results are requested. Returns @@ -1373,14 +1377,12 @@ class TradingAlgorithm(object): """ # NOTE: We don't currently support multiple pipelines, but we plan to # in the future. - for p in self._pipelines: - if p.name == name: - break - # This is a for-else block. Yes, that's a thing in Python. - else: + try: + p = self._pipelines[name] + except KeyError: raise NoSuchPipeline( name=name, - valid=[p.name for p in self._pipelines], + valid=list(self._pipelines.keys()), ) return self._pipeline_output(p) diff --git a/zipline/pipeline/pipeline.py b/zipline/pipeline/pipeline.py index f50333ed..b893230e 100644 --- a/zipline/pipeline/pipeline.py +++ b/zipline/pipeline/pipeline.py @@ -7,10 +7,10 @@ from .graph import TermGraph class Pipeline(object): """ + A computational Pipeline for use in trading algorithms. + Parameters ---------- - name : str, optional - Name for this pipeline. columns : dict, optional Initial columns. screen : zipline.pipeline.term.Filter, optional @@ -20,34 +20,26 @@ class Pipeline(object): ------- add remove - apply_screen + set_screen Attributes ---------- columns screen """ - __slots__ = ('_name', '_columns', '_screen', '__weakref__') + __slots__ = ('_columns', '_screen', '__weakref__') @expect_types( - name=str, columns=optional(dict), screen=optional(Filter), ) - def __init__(self, name, columns=None, screen=None): - self._name = name + def __init__(self, columns=None, screen=None): + if columns is None: columns = {} self._columns = columns self._screen = screen - @property - def name(self): - """ - The name of this pipeline. - """ - return self._name - @property def columns(self): """ @@ -114,10 +106,7 @@ class Pipeline(object): @expect_types(screen=Filter) def set_screen(self, screen, overwrite=False): """ - Apply a screen to this Pipeline. - - If no screen has yet been applied to the pipeline, this method sets - `screen` as the current screen. + Set a screen on this Pipeline. Parameter ---------