From d07d42263a94cb093c381f598ecf2cd4ad1771d5 Mon Sep 17 00:00:00 2001 From: Eddie Hebert Date: Thu, 17 Dec 2015 15:51:14 -0500 Subject: [PATCH] MAINT: Make tracker stats a method. Instead of calling a function, where the only parameter is the tracker object, make it a method, so that the snapshot of position tracker stats can be more easily called as `pt.stats()`. --- tests/test_perf_tracking.py | 5 +- zipline/finance/performance/period.py | 7 +- .../finance/performance/position_tracker.py | 90 +++++++++---------- 3 files changed, 48 insertions(+), 54 deletions(-) diff --git a/tests/test_perf_tracking.py b/tests/test_perf_tracking.py index bb3a766f..976d247c 100644 --- a/tests/test_perf_tracking.py +++ b/tests/test_perf_tracking.py @@ -35,7 +35,6 @@ from six.moves import range, zip import zipline.utils.factory as factory import zipline.finance.performance as perf -from zipline.finance.performance import position_tracker from zipline.finance.transaction import Transaction, create_transaction import zipline.utils.math_utils as zp_math @@ -2209,7 +2208,7 @@ class TestPositionTracker(unittest.TestCase): np.bool_(False) """ pt = perf.PositionTracker(self.env.asset_finder) - pos_stats = position_tracker.calc_position_stats(pt) + pos_stats = pt.stats() stats = [ 'net_value', @@ -2263,7 +2262,7 @@ class TestPositionTracker(unittest.TestCase): # Test long-only methods - pos_stats = position_tracker.calc_position_stats(pt) + pos_stats = pt.stats() self.assertEqual(100, pos_stats.long_value) self.assertEqual(100 + 300000, pos_stats.long_exposure) self.assertEqual(2, pos_stats.longs_count) diff --git a/zipline/finance/performance/period.py b/zipline/finance/performance/period.py index 67dec7a7..7f98ca94 100644 --- a/zipline/finance/performance/period.py +++ b/zipline/finance/performance/period.py @@ -91,7 +91,6 @@ import zipline.protocol as zp from zipline.utils.serialization_utils import ( VERSION_LABEL ) -from zipline.finance.performance.position_tracker import calc_position_stats log = logbook.Logger('Performance') TRADE_TYPE = zp.DATASOURCE_TYPE.TRADE @@ -210,7 +209,7 @@ class PerformancePeriod(object): def calculate_performance(self): pt = self.position_tracker - pos_stats = calc_position_stats(pt) + pos_stats = pt.stats() self.ending_value = pos_stats.net_value self.ending_exposure = pos_stats.net_exposure @@ -279,7 +278,7 @@ class PerformancePeriod(object): return self.position_tracker.position_amounts def __core_dict(self): - pos_stats = calc_position_stats(self.position_tracker) + pos_stats = self.position_tracker.stats() period_stats = calc_period_stats(pos_stats, self.ending_cash) rval = { @@ -384,7 +383,7 @@ class PerformancePeriod(object): account = self._account_store pt = self.position_tracker - pos_stats = calc_position_stats(pt) + pos_stats = pt.stats() period_stats = calc_period_stats(pos_stats, self.ending_cash) # If no attribute is found on the PerformancePeriod resort to the diff --git a/zipline/finance/performance/position_tracker.py b/zipline/finance/performance/position_tracker.py index 04e3b276..eac750be 100644 --- a/zipline/finance/performance/position_tracker.py +++ b/zipline/finance/performance/position_tracker.py @@ -121,53 +121,6 @@ def calc_gross_value(long_value, short_value): return long_value + abs(short_value) -def calc_position_stats(pt): - amounts = [] - last_sale_prices = [] - for pos in itervalues(pt.positions): - amounts.append(pos.amount) - last_sale_prices.append(pos.last_sale_price) - - position_value_multipliers = pt._position_value_multipliers - position_exposure_multipliers = pt._position_exposure_multipliers - - position_values = calc_position_values( - amounts, - last_sale_prices, - position_value_multipliers - ) - - position_exposures = calc_position_exposures( - amounts, - last_sale_prices, - position_exposure_multipliers - ) - - long_value = calc_long_value(position_values) - short_value = calc_short_value(position_values) - gross_value = calc_gross_value(long_value, short_value) - long_exposure = calc_long_exposure(position_exposures) - short_exposure = calc_short_exposure(position_exposures) - gross_exposure = calc_gross_exposure(long_exposure, short_exposure) - net_exposure = calc_net(position_exposures) - longs_count = calc_longs_count(position_exposures) - shorts_count = calc_shorts_count(position_exposures) - net_value = calc_net(position_values) - - return PositionStats( - long_value=long_value, - gross_value=gross_value, - short_value=short_value, - long_exposure=long_exposure, - short_exposure=short_exposure, - gross_exposure=gross_exposure, - net_exposure=net_exposure, - longs_count=longs_count, - shorts_count=shorts_count, - net_value=net_value - ) - - class PositionTracker(object): def __init__(self, asset_finder): @@ -456,6 +409,49 @@ class PositionTracker(object): positions.append(pos.to_dict()) return positions + def stats(self): + amounts = [] + last_sale_prices = [] + for pos in itervalues(self.positions): + amounts.append(pos.amount) + last_sale_prices.append(pos.last_sale_price) + + position_values = calc_position_values( + amounts, + last_sale_prices, + self._position_value_multipliers + ) + + position_exposures = calc_position_exposures( + amounts, + last_sale_prices, + self._position_exposure_multipliers + ) + + long_value = calc_long_value(position_values) + short_value = calc_short_value(position_values) + gross_value = calc_gross_value(long_value, short_value) + long_exposure = calc_long_exposure(position_exposures) + short_exposure = calc_short_exposure(position_exposures) + gross_exposure = calc_gross_exposure(long_exposure, short_exposure) + net_exposure = calc_net(position_exposures) + longs_count = calc_longs_count(position_exposures) + shorts_count = calc_shorts_count(position_exposures) + net_value = calc_net(position_values) + + return PositionStats( + long_value=long_value, + gross_value=gross_value, + short_value=short_value, + long_exposure=long_exposure, + short_exposure=short_exposure, + gross_exposure=gross_exposure, + net_exposure=net_exposure, + longs_count=longs_count, + shorts_count=shorts_count, + net_value=net_value + ) + def __getstate__(self): state_dict = {}