diff --git a/tests/test_perf_tracking.py b/tests/test_perf_tracking.py index f0ac8af6..94acc3d0 100644 --- a/tests/test_perf_tracking.py +++ b/tests/test_perf_tracking.py @@ -2056,6 +2056,41 @@ class TestPositionTracker(WithTradingEnvironment, self.assertEqual(100 + 200 + 300000 + 400000, pos_stats.gross_exposure) self.assertEqual(100 - 200 + 300000 - 400000, pos_stats.net_exposure) + def test_cost_basis(self): + dt = pd.Timestamp("2015-12-10 15:00", tz='UTC') + + equity_pos = perf.Position( + self.EQUITY1, + amount=10, + last_sale_date=dt, + cost_basis=10, + last_sale_price=11, + ) + + future_pos = perf.Position( + self.FUTURE3, + amount=10, + last_sale_date=dt, + cost_basis=10, + last_sale_price=11, + ) + + self.assertEqual(10, equity_pos.cost_basis) + + # send a $5 commission to the equity position. Spread out over 10 + # shares, that bumps the cost basis by $0.50. + equity_pos.adjust_commission_cost_basis(self.EQUITY1, 5) + self.assertEqual(10.5, equity_pos.cost_basis) + + self.assertEqual(10, future_pos.cost_basis) + + # send a $5k commission to the futures position. since self.FUTURE3 + # has a contract size (multipler) of 1000, this should result in a + # $10.5 updated cost basis. (5000 / 1000 = $5, spread out over 10 + # contracts, is $0.50 extra per contract). + future_pos.adjust_commission_cost_basis(self.FUTURE3, 5000) + self.assertEqual(10.5, future_pos.cost_basis) + def test_update_positions(self): pt = perf.PositionTracker(self.env.asset_finder, None) dt = pd.Timestamp("2014/01/01 3:00PM") diff --git a/zipline/finance/performance/position.py b/zipline/finance/performance/position.py index 24bc5bb4..a73ee913 100644 --- a/zipline/finance/performance/position.py +++ b/zipline/finance/performance/position.py @@ -173,7 +173,11 @@ class Position(object): return prev_cost = self.cost_basis * self.amount - new_cost = prev_cost + cost + if isinstance(asset, Future): + cost_to_use = cost / asset.multiplier + else: + cost_to_use = cost + new_cost = prev_cost + cost_to_use self.cost_basis = new_cost / self.amount def __repr__(self):