mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 20:11:46 +08:00
BUG: Fix cost basis calculation
Cost basis calculation now takes direction of txn into account. Closing a long position or covering a short shouldn't affect the cost basis.
This commit is contained in:
@@ -993,9 +993,9 @@ shares in position"
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
round(pp.positions[1].cost_basis, 2),
|
||||
11.33,
|
||||
"should have a cost basis of 11.33"
|
||||
pp.positions[1].cost_basis,
|
||||
11,
|
||||
"should have a cost basis of 11"
|
||||
)
|
||||
|
||||
self.assertEqual(pp.pnl, -800, "this period goes from +400 to -400")
|
||||
@@ -1022,9 +1022,9 @@ shares in position"
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
round(pp3.positions[1].cost_basis, 2),
|
||||
11.33,
|
||||
"should have a cost basis of 11.33"
|
||||
pp3.positions[1].cost_basis,
|
||||
11,
|
||||
"should have a cost basis of 11"
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
@@ -1033,6 +1033,32 @@ shares in position"
|
||||
"should be -400 for all trades and transactions in period"
|
||||
)
|
||||
|
||||
def test_cost_basis_calc_close_pos(self):
|
||||
history_args = (
|
||||
1,
|
||||
[10, 9, 11, 8, 9, 12, 13, 14],
|
||||
[200, -100, -100, 100, -300, 100, 500, 400],
|
||||
onesec,
|
||||
self.sim_params
|
||||
)
|
||||
cost_bases = [10, 10, 0, 8, 9, 9, 13, 13.5]
|
||||
|
||||
trades = factory.create_trade_history(*history_args)
|
||||
transactions = factory.create_txn_history(*history_args)
|
||||
|
||||
pp = perf.PerformancePeriod(1000.0)
|
||||
|
||||
for txn, cb in zip(transactions, cost_bases):
|
||||
pp.execute_transaction(txn)
|
||||
self.assertEqual(pp.positions[1].cost_basis, cb)
|
||||
|
||||
for trade in trades:
|
||||
pp.update_last_sale(trade)
|
||||
|
||||
pp.calculate_performance()
|
||||
|
||||
self.assertEqual(pp.positions[1].cost_basis, cost_bases[-1])
|
||||
|
||||
|
||||
class TestPerformanceTracker(unittest.TestCase):
|
||||
|
||||
|
||||
@@ -136,17 +136,27 @@ class Position(object):
|
||||
raise Exception('updating position with txn for a '
|
||||
'different sid')
|
||||
|
||||
# we're covering a short or closing a position
|
||||
if(self.amount + txn.amount == 0):
|
||||
total_shares = self.amount + txn.amount
|
||||
|
||||
if total_shares == 0:
|
||||
self.cost_basis = 0.0
|
||||
self.amount = 0
|
||||
else:
|
||||
prev_cost = self.cost_basis * self.amount
|
||||
txn_cost = txn.amount * txn.price
|
||||
total_cost = prev_cost + txn_cost
|
||||
total_shares = self.amount + txn.amount
|
||||
self.cost_basis = total_cost / total_shares
|
||||
self.amount = total_shares
|
||||
prev_direction = math.copysign(1, self.amount)
|
||||
txn_direction = math.copysign(1, txn.amount)
|
||||
|
||||
if prev_direction != txn_direction:
|
||||
# we're covering a short or closing a position
|
||||
if abs(txn.amount) > abs(self.amount):
|
||||
# we've closed the position and gone short
|
||||
# or covered the short position and gone long
|
||||
self.cost_basis = txn.price
|
||||
else:
|
||||
prev_cost = self.cost_basis * self.amount
|
||||
txn_cost = txn.amount * txn.price
|
||||
total_cost = prev_cost + txn_cost
|
||||
self.cost_basis = total_cost / total_shares
|
||||
|
||||
self.amount = total_shares
|
||||
|
||||
def adjust_commission_cost_basis(self, commission):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user