diff --git a/tests/test_perf_tracking.py b/tests/test_perf_tracking.py index 238f5b8a..e3ed169c 100644 --- a/tests/test_perf_tracking.py +++ b/tests/test_perf_tracking.py @@ -993,9 +993,9 @@ shares in position" ) self.assertEqual( - round(pp.positions[1].cost_basis, 2), - 11.33, - "should have a cost basis of 11.33" + pp.positions[1].cost_basis, + 11, + "should have a cost basis of 11" ) self.assertEqual(pp.pnl, -800, "this period goes from +400 to -400") @@ -1022,9 +1022,9 @@ shares in position" ) self.assertEqual( - round(pp3.positions[1].cost_basis, 2), - 11.33, - "should have a cost basis of 11.33" + pp3.positions[1].cost_basis, + 11, + "should have a cost basis of 11" ) self.assertEqual( @@ -1033,6 +1033,32 @@ shares in position" "should be -400 for all trades and transactions in period" ) + def test_cost_basis_calc_close_pos(self): + history_args = ( + 1, + [10, 9, 11, 8, 9, 12, 13, 14], + [200, -100, -100, 100, -300, 100, 500, 400], + onesec, + self.sim_params + ) + cost_bases = [10, 10, 0, 8, 9, 9, 13, 13.5] + + trades = factory.create_trade_history(*history_args) + transactions = factory.create_txn_history(*history_args) + + pp = perf.PerformancePeriod(1000.0) + + for txn, cb in zip(transactions, cost_bases): + pp.execute_transaction(txn) + self.assertEqual(pp.positions[1].cost_basis, cb) + + for trade in trades: + pp.update_last_sale(trade) + + pp.calculate_performance() + + self.assertEqual(pp.positions[1].cost_basis, cost_bases[-1]) + class TestPerformanceTracker(unittest.TestCase): diff --git a/zipline/finance/performance/position.py b/zipline/finance/performance/position.py index 01f72d29..9227eb2d 100644 --- a/zipline/finance/performance/position.py +++ b/zipline/finance/performance/position.py @@ -136,17 +136,27 @@ class Position(object): raise Exception('updating position with txn for a ' 'different sid') - # we're covering a short or closing a position - if(self.amount + txn.amount == 0): + total_shares = self.amount + txn.amount + + if total_shares == 0: self.cost_basis = 0.0 - self.amount = 0 else: - prev_cost = self.cost_basis * self.amount - txn_cost = txn.amount * txn.price - total_cost = prev_cost + txn_cost - total_shares = self.amount + txn.amount - self.cost_basis = total_cost / total_shares - self.amount = total_shares + prev_direction = math.copysign(1, self.amount) + txn_direction = math.copysign(1, txn.amount) + + if prev_direction != txn_direction: + # we're covering a short or closing a position + if abs(txn.amount) > abs(self.amount): + # we've closed the position and gone short + # or covered the short position and gone long + self.cost_basis = txn.price + else: + prev_cost = self.cost_basis * self.amount + txn_cost = txn.amount * txn.price + total_cost = prev_cost + txn_cost + self.cost_basis = total_cost / total_shares + + self.amount = total_shares def adjust_commission_cost_basis(self, commission): """