diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index e8d67335..85715f50 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -48,6 +48,7 @@ from zipline.test_algorithms import ( EmptyPositionsAlgorithm, InvalidOrderAlgorithm, RecordAlgorithm, + FutureFlipAlgo, TestAlgorithm, TestOrderAlgorithm, TestOrderInstantAlgorithm, @@ -1842,3 +1843,52 @@ class TestClosePosAlgo(TestCase): actual_position, expected_positions[i], "position for day={0} not equal, actual={1}, expected={2}". format(i, actual_position, expected_positions[i])) + + +class TestFutureFlip(TestCase): + def setUp(self): + self.env = TradingEnvironment() + self.days = self.env.trading_days[:4] + self.trades_panel = pd.Panel({1: pd.DataFrame({ + 'price': [1, 2, 4], 'volume': [1e9, 1e9, 1e9], + 'type': [DATASOURCE_TYPE.TRADE, + DATASOURCE_TYPE.TRADE, + DATASOURCE_TYPE.TRADE]}, + index=self.days[:3]) + }) + + def test_flip_algo(self): + metadata = {1: {'symbol': 'TEST', + 'asset_type': 'equity', + 'end_date': self.days[3], + 'contract_multiplier': 5}} + self.env.write_data(futures_data=metadata) + + algo = FutureFlipAlgo(sid=1, amount=1, env=self.env, + commission=PerShare(0), + order_count=0, # not applicable but required + instant_fill=True) + data = DataPanelSource(self.trades_panel) + + results = algo.run(data) + + expected_positions = [1, -1, 0] + self.check_algo_positions(results, expected_positions) + + expected_pnl = [0, 5, -10] + self.check_algo_pnl(results, expected_pnl) + + def check_algo_pnl(self, results, expected_pnl): + np.testing.assert_array_almost_equal(results.pnl, expected_pnl) + + def check_algo_positions(self, results, expected_positions): + for i, amount in enumerate(results.positions): + if amount: + actual_position = amount[0]['amount'] + else: + actual_position = 0 + + self.assertEqual( + actual_position, expected_positions[i], + "position for day={0} not equal, actual={1}, expected={2}". + format(i, actual_position, expected_positions[i])) diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 496c3739..a56d001f 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -442,6 +442,16 @@ class TestTargetValueAlgorithm(TradingAlgorithm): 20 / (data[0].price * self.sid(0).contract_multiplier)) +class FutureFlipAlgo(TestAlgorithm): + def handle_data(self, data): + if len(self.portfolio.positions) > 0: + if self.portfolio.positions[self.asset.sid]["amount"] > 0: + self.order_target(self.asset, -self.amount) + else: + self.order_target(self.asset, 0) + else: + self.order_target(self.asset, self.amount) + ############################ # AccountControl Test Algos# ############################