From 3d18a2954f2db072f6203040cdce5157ad7f7a71 Mon Sep 17 00:00:00 2001 From: warren-oneill Date: Wed, 10 Jun 2015 13:41:48 +0200 Subject: [PATCH] TST:adds algo unittest for ClOSE_POSITON event type, adds commission as parameter to TestAlgorithm --- tests/test_algorithm.py | 48 ++++++++++++++++++++++++++++++++++++++ zipline/test_algorithms.py | 6 ++++- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 2fee8382..67c7c61b 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -44,6 +44,7 @@ from zipline.test_algorithms import ( EmptyPositionsAlgorithm, InvalidOrderAlgorithm, RecordAlgorithm, + TestAlgorithm, TestOrderAlgorithm, TestOrderInstantAlgorithm, TestOrderPercentAlgorithm, @@ -90,6 +91,10 @@ from zipline.finance.trading import SimulationParameters from zipline.utils.api_support import set_algo_instance from zipline.utils.events import DateRuleFactory, TimeRuleFactory from zipline.algorithm import TradingAlgorithm +from zipline.finance import trading +from zipline.protocol import DATASOURCE_TYPE +from zipline.finance.trading import TradingEnvironment +from zipline.finance.commission import PerShare class TestRecordAlgorithm(TestCase): @@ -1178,3 +1183,46 @@ class TestAccountControls(TestCase): algo = SetMaxLeverageAlgorithm(1) self.check_algo_succeeds(algo, handle_data) + + +class TestClosePosAlgo(TestCase): + + def setUp(self): + days = TradingEnvironment().trading_days + self.index = [days[0], days[1], days[2]] + pan = pd.Panel({1: pd.DataFrame({ + 'price': [1, 2, 4], 'volume': [1e9, 0, 0], + 'type': [DATASOURCE_TYPE.TRADE, + DATASOURCE_TYPE.TRADE, + DATASOURCE_TYPE.CLOSE_POSITION]}, + index=self.index) + }) + + self.data = DataPanelSource(pan) + self.algo = TestAlgorithm(sid=1, amount=1, order_count=1, + instant_fill=True, commission=PerShare(0)) + self.results = self.run_algo() + self.expected_positions = [1, 1, 0] + self.expected_pnl = [0, 1, 2] + + def run_algo(self): + results = self.algo.run(self.data) + return results + + def test_algo_pnl(self): + for i, pnl in enumerate(self.results.pnl): + self.assertEqual(pnl, self.expected_pnl[i]) + + def test_algo_positions(self): + for i, amount in enumerate(self.results.positions): + if amount: + actual_position = amount[0]['amount'] + else: + actual_position = 0 + + self.assertEqual(actual_position, self.expected_positions[i]) + + def tearDown(self): + pass + self.algo = None + trading.environment = None diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 83715ccd..3d99f30b 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -108,7 +108,8 @@ class TestAlgorithm(TradingAlgorithm): amount, order_count, sid_filter=None, - slippage=None): + slippage=None, + commission=None): self.count = order_count self.sid = sid self.amount = amount @@ -122,6 +123,9 @@ class TestAlgorithm(TradingAlgorithm): if slippage is not None: self.set_slippage(slippage) + if commission is not None: + self.set_commission(commission) + def handle_data(self, data): # place an order for amount shares of sid if self.incr < self.count: