mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 06:32:09 +08:00
MAINT: Make tracker stats a method.
Instead of calling a function, where the only parameter is the tracker object, make it a method, so that the snapshot of position tracker stats can be more easily called as `pt.stats()`.
This commit is contained in:
@@ -35,7 +35,6 @@ from six.moves import range, zip
|
||||
|
||||
import zipline.utils.factory as factory
|
||||
import zipline.finance.performance as perf
|
||||
from zipline.finance.performance import position_tracker
|
||||
from zipline.finance.transaction import Transaction, create_transaction
|
||||
import zipline.utils.math_utils as zp_math
|
||||
|
||||
@@ -2209,7 +2208,7 @@ class TestPositionTracker(unittest.TestCase):
|
||||
np.bool_(False)
|
||||
"""
|
||||
pt = perf.PositionTracker(self.env.asset_finder)
|
||||
pos_stats = position_tracker.calc_position_stats(pt)
|
||||
pos_stats = pt.stats()
|
||||
|
||||
stats = [
|
||||
'net_value',
|
||||
@@ -2263,7 +2262,7 @@ class TestPositionTracker(unittest.TestCase):
|
||||
|
||||
# Test long-only methods
|
||||
|
||||
pos_stats = position_tracker.calc_position_stats(pt)
|
||||
pos_stats = pt.stats()
|
||||
self.assertEqual(100, pos_stats.long_value)
|
||||
self.assertEqual(100 + 300000, pos_stats.long_exposure)
|
||||
self.assertEqual(2, pos_stats.longs_count)
|
||||
|
||||
@@ -91,7 +91,6 @@ import zipline.protocol as zp
|
||||
from zipline.utils.serialization_utils import (
|
||||
VERSION_LABEL
|
||||
)
|
||||
from zipline.finance.performance.position_tracker import calc_position_stats
|
||||
|
||||
log = logbook.Logger('Performance')
|
||||
TRADE_TYPE = zp.DATASOURCE_TYPE.TRADE
|
||||
@@ -210,7 +209,7 @@ class PerformancePeriod(object):
|
||||
|
||||
def calculate_performance(self):
|
||||
pt = self.position_tracker
|
||||
pos_stats = calc_position_stats(pt)
|
||||
pos_stats = pt.stats()
|
||||
self.ending_value = pos_stats.net_value
|
||||
self.ending_exposure = pos_stats.net_exposure
|
||||
|
||||
@@ -279,7 +278,7 @@ class PerformancePeriod(object):
|
||||
return self.position_tracker.position_amounts
|
||||
|
||||
def __core_dict(self):
|
||||
pos_stats = calc_position_stats(self.position_tracker)
|
||||
pos_stats = self.position_tracker.stats()
|
||||
period_stats = calc_period_stats(pos_stats, self.ending_cash)
|
||||
|
||||
rval = {
|
||||
@@ -384,7 +383,7 @@ class PerformancePeriod(object):
|
||||
account = self._account_store
|
||||
|
||||
pt = self.position_tracker
|
||||
pos_stats = calc_position_stats(pt)
|
||||
pos_stats = pt.stats()
|
||||
period_stats = calc_period_stats(pos_stats, self.ending_cash)
|
||||
|
||||
# If no attribute is found on the PerformancePeriod resort to the
|
||||
|
||||
@@ -121,53 +121,6 @@ def calc_gross_value(long_value, short_value):
|
||||
return long_value + abs(short_value)
|
||||
|
||||
|
||||
def calc_position_stats(pt):
|
||||
amounts = []
|
||||
last_sale_prices = []
|
||||
for pos in itervalues(pt.positions):
|
||||
amounts.append(pos.amount)
|
||||
last_sale_prices.append(pos.last_sale_price)
|
||||
|
||||
position_value_multipliers = pt._position_value_multipliers
|
||||
position_exposure_multipliers = pt._position_exposure_multipliers
|
||||
|
||||
position_values = calc_position_values(
|
||||
amounts,
|
||||
last_sale_prices,
|
||||
position_value_multipliers
|
||||
)
|
||||
|
||||
position_exposures = calc_position_exposures(
|
||||
amounts,
|
||||
last_sale_prices,
|
||||
position_exposure_multipliers
|
||||
)
|
||||
|
||||
long_value = calc_long_value(position_values)
|
||||
short_value = calc_short_value(position_values)
|
||||
gross_value = calc_gross_value(long_value, short_value)
|
||||
long_exposure = calc_long_exposure(position_exposures)
|
||||
short_exposure = calc_short_exposure(position_exposures)
|
||||
gross_exposure = calc_gross_exposure(long_exposure, short_exposure)
|
||||
net_exposure = calc_net(position_exposures)
|
||||
longs_count = calc_longs_count(position_exposures)
|
||||
shorts_count = calc_shorts_count(position_exposures)
|
||||
net_value = calc_net(position_values)
|
||||
|
||||
return PositionStats(
|
||||
long_value=long_value,
|
||||
gross_value=gross_value,
|
||||
short_value=short_value,
|
||||
long_exposure=long_exposure,
|
||||
short_exposure=short_exposure,
|
||||
gross_exposure=gross_exposure,
|
||||
net_exposure=net_exposure,
|
||||
longs_count=longs_count,
|
||||
shorts_count=shorts_count,
|
||||
net_value=net_value
|
||||
)
|
||||
|
||||
|
||||
class PositionTracker(object):
|
||||
|
||||
def __init__(self, asset_finder):
|
||||
@@ -456,6 +409,49 @@ class PositionTracker(object):
|
||||
positions.append(pos.to_dict())
|
||||
return positions
|
||||
|
||||
def stats(self):
|
||||
amounts = []
|
||||
last_sale_prices = []
|
||||
for pos in itervalues(self.positions):
|
||||
amounts.append(pos.amount)
|
||||
last_sale_prices.append(pos.last_sale_price)
|
||||
|
||||
position_values = calc_position_values(
|
||||
amounts,
|
||||
last_sale_prices,
|
||||
self._position_value_multipliers
|
||||
)
|
||||
|
||||
position_exposures = calc_position_exposures(
|
||||
amounts,
|
||||
last_sale_prices,
|
||||
self._position_exposure_multipliers
|
||||
)
|
||||
|
||||
long_value = calc_long_value(position_values)
|
||||
short_value = calc_short_value(position_values)
|
||||
gross_value = calc_gross_value(long_value, short_value)
|
||||
long_exposure = calc_long_exposure(position_exposures)
|
||||
short_exposure = calc_short_exposure(position_exposures)
|
||||
gross_exposure = calc_gross_exposure(long_exposure, short_exposure)
|
||||
net_exposure = calc_net(position_exposures)
|
||||
longs_count = calc_longs_count(position_exposures)
|
||||
shorts_count = calc_shorts_count(position_exposures)
|
||||
net_value = calc_net(position_values)
|
||||
|
||||
return PositionStats(
|
||||
long_value=long_value,
|
||||
gross_value=gross_value,
|
||||
short_value=short_value,
|
||||
long_exposure=long_exposure,
|
||||
short_exposure=short_exposure,
|
||||
gross_exposure=gross_exposure,
|
||||
net_exposure=net_exposure,
|
||||
longs_count=longs_count,
|
||||
shorts_count=shorts_count,
|
||||
net_value=net_value
|
||||
)
|
||||
|
||||
def __getstate__(self):
|
||||
state_dict = {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user