MAINT: Make tracker stats a method.

Instead of calling a function, where the only parameter is the tracker
object, make it a method, so that the snapshot of position tracker stats
can be more easily called as `pt.stats()`.
This commit is contained in:
Eddie Hebert
2015-12-17 15:51:14 -05:00
parent 4bb269490f
commit d07d42263a
3 changed files with 48 additions and 54 deletions
+2 -3
View File
@@ -35,7 +35,6 @@ from six.moves import range, zip
import zipline.utils.factory as factory
import zipline.finance.performance as perf
from zipline.finance.performance import position_tracker
from zipline.finance.transaction import Transaction, create_transaction
import zipline.utils.math_utils as zp_math
@@ -2209,7 +2208,7 @@ class TestPositionTracker(unittest.TestCase):
np.bool_(False)
"""
pt = perf.PositionTracker(self.env.asset_finder)
pos_stats = position_tracker.calc_position_stats(pt)
pos_stats = pt.stats()
stats = [
'net_value',
@@ -2263,7 +2262,7 @@ class TestPositionTracker(unittest.TestCase):
# Test long-only methods
pos_stats = position_tracker.calc_position_stats(pt)
pos_stats = pt.stats()
self.assertEqual(100, pos_stats.long_value)
self.assertEqual(100 + 300000, pos_stats.long_exposure)
self.assertEqual(2, pos_stats.longs_count)
+3 -4
View File
@@ -91,7 +91,6 @@ import zipline.protocol as zp
from zipline.utils.serialization_utils import (
VERSION_LABEL
)
from zipline.finance.performance.position_tracker import calc_position_stats
log = logbook.Logger('Performance')
TRADE_TYPE = zp.DATASOURCE_TYPE.TRADE
@@ -210,7 +209,7 @@ class PerformancePeriod(object):
def calculate_performance(self):
pt = self.position_tracker
pos_stats = calc_position_stats(pt)
pos_stats = pt.stats()
self.ending_value = pos_stats.net_value
self.ending_exposure = pos_stats.net_exposure
@@ -279,7 +278,7 @@ class PerformancePeriod(object):
return self.position_tracker.position_amounts
def __core_dict(self):
pos_stats = calc_position_stats(self.position_tracker)
pos_stats = self.position_tracker.stats()
period_stats = calc_period_stats(pos_stats, self.ending_cash)
rval = {
@@ -384,7 +383,7 @@ class PerformancePeriod(object):
account = self._account_store
pt = self.position_tracker
pos_stats = calc_position_stats(pt)
pos_stats = pt.stats()
period_stats = calc_period_stats(pos_stats, self.ending_cash)
# If no attribute is found on the PerformancePeriod resort to the
+43 -47
View File
@@ -121,53 +121,6 @@ def calc_gross_value(long_value, short_value):
return long_value + abs(short_value)
def calc_position_stats(pt):
amounts = []
last_sale_prices = []
for pos in itervalues(pt.positions):
amounts.append(pos.amount)
last_sale_prices.append(pos.last_sale_price)
position_value_multipliers = pt._position_value_multipliers
position_exposure_multipliers = pt._position_exposure_multipliers
position_values = calc_position_values(
amounts,
last_sale_prices,
position_value_multipliers
)
position_exposures = calc_position_exposures(
amounts,
last_sale_prices,
position_exposure_multipliers
)
long_value = calc_long_value(position_values)
short_value = calc_short_value(position_values)
gross_value = calc_gross_value(long_value, short_value)
long_exposure = calc_long_exposure(position_exposures)
short_exposure = calc_short_exposure(position_exposures)
gross_exposure = calc_gross_exposure(long_exposure, short_exposure)
net_exposure = calc_net(position_exposures)
longs_count = calc_longs_count(position_exposures)
shorts_count = calc_shorts_count(position_exposures)
net_value = calc_net(position_values)
return PositionStats(
long_value=long_value,
gross_value=gross_value,
short_value=short_value,
long_exposure=long_exposure,
short_exposure=short_exposure,
gross_exposure=gross_exposure,
net_exposure=net_exposure,
longs_count=longs_count,
shorts_count=shorts_count,
net_value=net_value
)
class PositionTracker(object):
def __init__(self, asset_finder):
@@ -456,6 +409,49 @@ class PositionTracker(object):
positions.append(pos.to_dict())
return positions
def stats(self):
amounts = []
last_sale_prices = []
for pos in itervalues(self.positions):
amounts.append(pos.amount)
last_sale_prices.append(pos.last_sale_price)
position_values = calc_position_values(
amounts,
last_sale_prices,
self._position_value_multipliers
)
position_exposures = calc_position_exposures(
amounts,
last_sale_prices,
self._position_exposure_multipliers
)
long_value = calc_long_value(position_values)
short_value = calc_short_value(position_values)
gross_value = calc_gross_value(long_value, short_value)
long_exposure = calc_long_exposure(position_exposures)
short_exposure = calc_short_exposure(position_exposures)
gross_exposure = calc_gross_exposure(long_exposure, short_exposure)
net_exposure = calc_net(position_exposures)
longs_count = calc_longs_count(position_exposures)
shorts_count = calc_shorts_count(position_exposures)
net_value = calc_net(position_values)
return PositionStats(
long_value=long_value,
gross_value=gross_value,
short_value=short_value,
long_exposure=long_exposure,
short_exposure=short_exposure,
gross_exposure=gross_exposure,
net_exposure=net_exposure,
longs_count=longs_count,
shorts_count=shorts_count,
net_value=net_value
)
def __getstate__(self):
state_dict = {}