From 244664b6a38de2d3b4c7ef2f3c3d4d58ee3f9ce7 Mon Sep 17 00:00:00 2001 From: Scott Sanderson Date: Tue, 24 May 2016 21:56:39 -0400 Subject: [PATCH] MAINT: Clean up default handling in TradingAlgorithm. --- tests/test_algorithm.py | 30 ++++++++++++++++++++++++++ zipline/algorithm.py | 48 ++++++++++++++++++++++++++--------------- 2 files changed, 61 insertions(+), 17 deletions(-) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 07b58831..7eec1176 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -3393,3 +3393,33 @@ class TestOrderAfterDelist(WithTradingEnvironment, ZiplineTestCase): "asset will be liquidated on " "2016-01-11 00:00:00+00:00.", w.message) + + +class AlgoInputValidationTestCase(ZiplineTestCase): + + def test_reject_passing_both_api_methods_and_script(self): + script = dedent( + """ + def initialize(context): + pass + + def handle_data(context, data): + pass + + def before_trading_start(context, data): + pass + + def analyze(context, results): + pass + """ + ) + for method in ('initialize', + 'handle_data', + 'before_trading_start', + 'analyze'): + + with self.assertRaises(ValueError): + TradingAlgorithm( + script=script, + **{method: lambda *args, **kwargs: None} + ) diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 2e6fc3c2..5e6b5438 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -31,6 +31,7 @@ from six import ( iteritems, itervalues, string_types, + viewkeys, ) from zipline._protocol import handle_non_market_minutes @@ -332,29 +333,46 @@ class TradingAlgorithm(object): self._handle_data = None + def noop(*args, **kwargs): + pass + if self.algoscript is not None: + api_methods = { + 'initialize', + 'handle_data', + 'before_trading_start', + 'analyze', + } + unexpected_api_methods = viewkeys(kwargs) & api_methods + if unexpected_api_methods: + raise ValueError( + "TradingAlgorithm received a script and the following API" + " methods as functions:\n{funcs}".format( + funcs=unexpected_api_methods, + ) + ) + filename = kwargs.pop('algo_filename', None) if filename is None: filename = '' code = compile(self.algoscript, filename, 'exec') exec_(code, self.namespace) - self._initialize = self.namespace.get('initialize') - if 'handle_data' in self.namespace: - self._handle_data = self.namespace['handle_data'] - self._before_trading_start = \ - self.namespace.get('before_trading_start') + self._initialize = self.namespace.get('initialize', noop) + self._handle_data = self.namespace.get('handle_data', noop) + self._before_trading_start = self.namespace.get( + 'before_trading_start', + ) # Optional analyze function, gets called after run self._analyze = self.namespace.get('analyze') - elif kwargs.get('initialize') and kwargs.get('handle_data'): - if self.algoscript is not None: - raise ValueError('You can not set script and \ - initialize/handle_data.') - self._initialize = kwargs.pop('initialize') - self._handle_data = kwargs.pop('handle_data') - self._before_trading_start = kwargs.pop('before_trading_start', - None) + else: + self._initialize = kwargs.pop('initialize', noop) + self._handle_data = kwargs.pop('handle_data', noop) + self._before_trading_start = kwargs.pop( + 'before_trading_start', + None, + ) self._analyze = kwargs.pop('analyze', None) self.event_manager.add_event( @@ -367,10 +385,6 @@ class TradingAlgorithm(object): prepend=True, ) - # If method not defined, NOOP - if self._initialize is None: - self._initialize = lambda x: None - # Alternative way of setting data_frequency for backwards # compatibility. if 'data_frequency' in kwargs: