BUG: Fix cost basis calculation

Cost basis calculation now takes direction of txn into account.

Closing a long position or covering a short shouldn't affect the cost basis.
This commit is contained in:
Richard Frank
2014-03-04 13:35:44 -05:00
parent e459c2729c
commit 5020c36f8d
2 changed files with 51 additions and 15 deletions
+32 -6
View File
@@ -993,9 +993,9 @@ shares in position"
)
self.assertEqual(
round(pp.positions[1].cost_basis, 2),
11.33,
"should have a cost basis of 11.33"
pp.positions[1].cost_basis,
11,
"should have a cost basis of 11"
)
self.assertEqual(pp.pnl, -800, "this period goes from +400 to -400")
@@ -1022,9 +1022,9 @@ shares in position"
)
self.assertEqual(
round(pp3.positions[1].cost_basis, 2),
11.33,
"should have a cost basis of 11.33"
pp3.positions[1].cost_basis,
11,
"should have a cost basis of 11"
)
self.assertEqual(
@@ -1033,6 +1033,32 @@ shares in position"
"should be -400 for all trades and transactions in period"
)
def test_cost_basis_calc_close_pos(self):
history_args = (
1,
[10, 9, 11, 8, 9, 12, 13, 14],
[200, -100, -100, 100, -300, 100, 500, 400],
onesec,
self.sim_params
)
cost_bases = [10, 10, 0, 8, 9, 9, 13, 13.5]
trades = factory.create_trade_history(*history_args)
transactions = factory.create_txn_history(*history_args)
pp = perf.PerformancePeriod(1000.0)
for txn, cb in zip(transactions, cost_bases):
pp.execute_transaction(txn)
self.assertEqual(pp.positions[1].cost_basis, cb)
for trade in trades:
pp.update_last_sale(trade)
pp.calculate_performance()
self.assertEqual(pp.positions[1].cost_basis, cost_bases[-1])
class TestPerformanceTracker(unittest.TestCase):
+19 -9
View File
@@ -136,17 +136,27 @@ class Position(object):
raise Exception('updating position with txn for a '
'different sid')
# we're covering a short or closing a position
if(self.amount + txn.amount == 0):
total_shares = self.amount + txn.amount
if total_shares == 0:
self.cost_basis = 0.0
self.amount = 0
else:
prev_cost = self.cost_basis * self.amount
txn_cost = txn.amount * txn.price
total_cost = prev_cost + txn_cost
total_shares = self.amount + txn.amount
self.cost_basis = total_cost / total_shares
self.amount = total_shares
prev_direction = math.copysign(1, self.amount)
txn_direction = math.copysign(1, txn.amount)
if prev_direction != txn_direction:
# we're covering a short or closing a position
if abs(txn.amount) > abs(self.amount):
# we've closed the position and gone short
# or covered the short position and gone long
self.cost_basis = txn.price
else:
prev_cost = self.cost_basis * self.amount
txn_cost = txn.amount * txn.price
total_cost = prev_cost + txn_cost
self.cost_basis = total_cost / total_shares
self.amount = total_shares
def adjust_commission_cost_basis(self, commission):
"""