diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 4101f48b..71acfe1f 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -96,7 +96,6 @@ from zipline.finance.trading import SimulationParameters from zipline.utils.api_support import set_algo_instance from zipline.utils.events import DateRuleFactory, TimeRuleFactory from zipline.algorithm import TradingAlgorithm -from zipline.finance import trading from zipline.protocol import DATASOURCE_TYPE from zipline.finance.trading import TradingEnvironment from zipline.finance.commission import PerShare @@ -1294,46 +1293,62 @@ class TestAccountControls(TestCase): class TestClosePosAlgo(TestCase): def setUp(self): - days = TradingEnvironment().trading_days - self.index = [days[0], days[1], days[2]] - pan = pd.Panel({1: pd.DataFrame({ + self.days = TradingEnvironment().trading_days + self.index = [self.days[0], self.days[1], self.days[2]] + self.panel = pd.Panel({1: pd.DataFrame({ 'price': [1, 2, 4], 'volume': [1e9, 0, 0], 'type': [DATASOURCE_TYPE.TRADE, DATASOURCE_TYPE.TRADE, DATASOURCE_TYPE.CLOSE_POSITION]}, index=self.index) }) - self.data = DataPanelSource(pan) + def test_close_position_equity(self): metadata = {1: {'symbol': 'TEST', - 'asset_type': 'future', - 'notice_date': days[2], - 'expiration_date': days[3]}} + 'asset_type': 'equity', + 'end_date': self.days[3]}} self.algo = TestAlgorithm(sid=1, amount=1, order_count=1, instant_fill=True, commission=PerShare(0), asset_metadata=metadata) - self.results = self.run_algo() - self.expected_positions = [1, 1, 0] - self.expected_pnl = [0, 1, 2] + self.data = DataPanelSource(self.panel) + + # Check results + expected_positions = [1, 1, 0] + expected_pnl = [0, 1, 2] + results = self.run_algo() + self.check_algo_pnl(results, expected_pnl) + self.check_algo_positions(results, expected_positions) + + def test_close_position_future(self): + metadata = {1: {'symbol': 'TEST', + 'asset_type': 'future', + 'notice_date': self.days[2], + 'expiration_date': self.days[3]}} + self.algo = TestAlgorithm(sid=1, amount=1, order_count=1, + instant_fill=True, commission=PerShare(0), + asset_metadata=metadata) + self.data = DataPanelSource(self.panel) + + # Check results + expected_positions = [1, 1, 0] + expected_pnl = [0, 1, 2] + results = self.run_algo() + self.check_algo_pnl(results, expected_pnl) + self.check_algo_positions(results, expected_positions) def run_algo(self): results = self.algo.run(self.data) return results - def test_algo_pnl(self): - for i, pnl in enumerate(self.results.pnl): - self.assertEqual(pnl, self.expected_pnl[i]) + def check_algo_pnl(self, results, expected_pnl): + for i, pnl in enumerate(results.pnl): + self.assertEqual(pnl, expected_pnl[i]) - def test_algo_positions(self): - for i, amount in enumerate(self.results.positions): + def check_algo_positions(self, results, expected_positions): + for i, amount in enumerate(results.positions): if amount: actual_position = amount[0]['amount'] else: actual_position = 0 - self.assertEqual(actual_position, self.expected_positions[i]) - - def tearDown(self): - pass - self.algo = None - trading.environment = None + self.assertEqual(actual_position, expected_positions[i])