diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 2ca03482..18e1c444 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -1891,3 +1891,34 @@ class TestFutureFlip(TestCase): actual_position, expected_positions[i], "position for day={0} not equal, actual={1}, expected={2}". format(i, actual_position, expected_positions[i])) + + +class TestTradingAlgorithm(TestCase): + def setUp(self): + self.env = TradingEnvironment() + self.days = self.env.trading_days[:4] + self.panel = pd.Panel({1: pd.DataFrame({ + 'price': [1, 1, 2, 4], 'volume': [1e9, 1e9, 1e9, 0], + 'type': [DATASOURCE_TYPE.TRADE, + DATASOURCE_TYPE.TRADE, + DATASOURCE_TYPE.TRADE, + DATASOURCE_TYPE.CLOSE_POSITION]}, + index=self.days) + }) + + def test_analyze_called(self): + self.perf_ref = None + + def initialize(context): + pass + + def handle_data(context, data): + pass + + def analyze(context, perf): + self.perf_ref = perf + + algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data, + analyze=analyze) + results = algo.run(self.panel) + self.assertIs(results, self.perf_ref) diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 38daf6e7..c2e3a628 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -297,6 +297,7 @@ class TradingAlgorithm(object): self._handle_data = kwargs.pop('handle_data') self._before_trading_start = kwargs.pop('before_trading_start', None) + self._analyze = kwargs.pop('analyze', None) self.event_manager.add_event( zipline.utils.events.Event(