MAINT: Refactored unit tests to remove duplication

of arguments to factory.create methods.

Also added checking of the perf period cost basis results after each txn.
This commit is contained in:
Richard Frank
2014-03-05 11:25:51 -05:00
parent e7ec629510
commit e459c2729c
+18 -15
View File
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import collections
import logging
import operator
@@ -23,7 +25,7 @@ import datetime
import pytz
import itertools
from six.moves import range
from six.moves import range, zip
import zipline.utils.factory as factory
import zipline.finance.performance as perf
@@ -925,26 +927,23 @@ shares in position"
)
def test_cost_basis_calc(self):
trades = factory.create_trade_history(
1,
[10, 11, 11, 12],
[100, 100, 100, 100],
onesec,
self.sim_params
)
transactions = factory.create_txn_history(
history_args = (
1,
[10, 11, 11, 12],
[100, 100, 100, 100],
onesec,
self.sim_params
)
trades = factory.create_trade_history(*history_args)
transactions = factory.create_txn_history(*history_args)
pp = perf.PerformancePeriod(1000.0)
for txn in transactions:
average_cost = 0
for i, txn in enumerate(transactions):
pp.execute_transaction(txn)
average_cost = (average_cost * i + txn.price) / (i + 1)
self.assertEqual(pp.positions[1].cost_basis, average_cost)
for trade in trades:
pp.update_last_sale(trade)
@@ -975,14 +974,14 @@ shares in position"
100,
trades[-1].dt + onesec)
saleTxn = create_txn(
sale_txn = create_txn(
down_tick,
10.0,
-100)
pp.rollover()
pp.execute_transaction(saleTxn)
pp.execute_transaction(sale_txn)
pp.update_last_sale(down_tick)
pp.calculate_performance()
@@ -1003,9 +1002,13 @@ shares in position"
pp3 = perf.PerformancePeriod(1000.0)
transactions.append(saleTxn)
for txn in transactions:
average_cost = 0
for i, txn in enumerate(transactions):
pp3.execute_transaction(txn)
average_cost = (average_cost * i + txn.price) / (i + 1)
self.assertEqual(pp3.positions[1].cost_basis, average_cost)
pp3.execute_transaction(sale_txn)
trades.append(down_tick)
for trade in trades: