mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 15:42:03 +08:00
MAINT: Don't store data portal everywhere
Removed lots of data portal references that participated in ref cycles and prevented deterministic cleanup of dbs.
This commit is contained in:
@@ -286,7 +286,7 @@ class FinanceTestCase(TestCase):
|
||||
else:
|
||||
alternator = 1
|
||||
|
||||
tracker = PerformanceTracker(sim_params, self.env, data_portal)
|
||||
tracker = PerformanceTracker(sim_params, self.env)
|
||||
|
||||
# replicate what tradesim does by going through every minute or day
|
||||
# of the simulation and processing open orders each time
|
||||
|
||||
+49
-81
@@ -50,7 +50,7 @@ from zipline.utils.serialization_utils import (
|
||||
loads_with_persistent_ids, dumps_with_persistent_ids
|
||||
)
|
||||
from zipline.testing.core import create_data_portal_from_trade_history, \
|
||||
create_empty_splits_mergers_frame, FakeDataPortal
|
||||
create_empty_splits_mergers_frame
|
||||
|
||||
logger = logging.getLogger('Test Perf Tracking')
|
||||
|
||||
@@ -167,7 +167,7 @@ def calculate_results(sim_params,
|
||||
splits = splits or {}
|
||||
commissions = commissions or {}
|
||||
|
||||
perf_tracker = perf.PerformanceTracker(sim_params, env, data_portal)
|
||||
perf_tracker = perf.PerformanceTracker(sim_params, env)
|
||||
|
||||
results = []
|
||||
|
||||
@@ -189,8 +189,10 @@ def calculate_results(sim_params,
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
msg = perf_tracker.handle_market_close_daily(date)
|
||||
perf_tracker.position_tracker.sync_last_sale_prices(date, False)
|
||||
msg = perf_tracker.handle_market_close_daily(date, data_portal)
|
||||
perf_tracker.position_tracker.sync_last_sale_prices(
|
||||
date, False, data_portal,
|
||||
)
|
||||
msg['account'] = perf_tracker.get_account(True)
|
||||
results.append(copy.deepcopy(msg))
|
||||
return results
|
||||
@@ -265,9 +267,7 @@ class TestSplitPerformance(unittest.TestCase):
|
||||
def test_multiple_splits(self):
|
||||
# if multiple positions all have splits at the same time, verify that
|
||||
# the total leftover cash is correct
|
||||
perf_tracker = perf.PerformanceTracker(
|
||||
self.sim_params, self.env, FakeDataPortal()
|
||||
)
|
||||
perf_tracker = perf.PerformanceTracker(self.sim_params, self.env)
|
||||
|
||||
asset1 = self.env.asset_finder.retrieve_asset(1)
|
||||
asset2 = self.env.asset_finder.retrieve_asset(2)
|
||||
@@ -1240,11 +1240,10 @@ class TestPositionPerformance(unittest.TestCase):
|
||||
txn1 = create_txn(self.asset1, trades_1[0].dt, 10.0, 100)
|
||||
txn2 = create_txn(self.asset2, trades_1[0].dt, 10.0, -100)
|
||||
|
||||
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
|
||||
pt = perf.PositionTracker(self.env.asset_finder,
|
||||
self.sim_params.data_frequency)
|
||||
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
|
||||
self.sim_params.data_frequency,
|
||||
data_portal)
|
||||
self.sim_params.data_frequency)
|
||||
pp.position_tracker = pt
|
||||
pt.execute_transaction(txn1)
|
||||
pp.handle_execution(txn1)
|
||||
@@ -1252,7 +1251,7 @@ class TestPositionPerformance(unittest.TestCase):
|
||||
pp.handle_execution(txn2)
|
||||
|
||||
dt = trades_1[-2].dt
|
||||
pt.sync_last_sale_prices(dt, False)
|
||||
pt.sync_last_sale_prices(dt, False, data_portal)
|
||||
|
||||
pp.calculate_performance()
|
||||
|
||||
@@ -1280,7 +1279,7 @@ class TestPositionPerformance(unittest.TestCase):
|
||||
net_liquidation=1000.0)
|
||||
|
||||
dt = trades_1[-1].dt
|
||||
pt.sync_last_sale_prices(dt, False)
|
||||
pt.sync_last_sale_prices(dt, False, data_portal)
|
||||
|
||||
pp.calculate_performance()
|
||||
|
||||
@@ -1333,11 +1332,10 @@ class TestPositionPerformance(unittest.TestCase):
|
||||
self.sim_params,
|
||||
{1: trades})
|
||||
txn = create_txn(self.asset1, trades[1].dt, 10.0, 1000)
|
||||
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
|
||||
pt = perf.PositionTracker(self.env.asset_finder,
|
||||
self.sim_params.data_frequency)
|
||||
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
|
||||
self.sim_params.data_frequency,
|
||||
data_portal)
|
||||
self.sim_params.data_frequency)
|
||||
pp.position_tracker = pt
|
||||
|
||||
pt.execute_transaction(txn)
|
||||
@@ -1355,7 +1353,7 @@ class TestPositionPerformance(unittest.TestCase):
|
||||
shorts_count=0)
|
||||
|
||||
# Validate that the account attributes were updated.
|
||||
pt.sync_last_sale_prices(trades[-2].dt, False)
|
||||
pt.sync_last_sale_prices(trades[-2].dt, False, data_portal)
|
||||
|
||||
# Validate that the account attributes were updated.
|
||||
account = pp.as_account()
|
||||
@@ -1373,7 +1371,7 @@ class TestPositionPerformance(unittest.TestCase):
|
||||
net_liquidation=1000.0)
|
||||
|
||||
# now simulate a price jump to $11
|
||||
pt.sync_last_sale_prices(trades[-1].dt, False)
|
||||
pt.sync_last_sale_prices(trades[-1].dt, False, data_portal)
|
||||
|
||||
pp.calculate_performance()
|
||||
|
||||
@@ -1425,11 +1423,10 @@ class TestPositionPerformance(unittest.TestCase):
|
||||
self.sim_params,
|
||||
{1: trades})
|
||||
txn = create_txn(self.asset1, trades[1].dt, 10.0, 100)
|
||||
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
|
||||
pt = perf.PositionTracker(self.env.asset_finder,
|
||||
self.sim_params.data_frequency)
|
||||
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
|
||||
self.sim_params.data_frequency,
|
||||
data_portal,
|
||||
period_open=self.sim_params.period_start,
|
||||
period_close=self.sim_params.period_end)
|
||||
pp.position_tracker = pt
|
||||
@@ -1444,7 +1441,7 @@ class TestPositionPerformance(unittest.TestCase):
|
||||
# stocks with a last sale price of 0.
|
||||
self.assertEqual(pp.positions[1].last_sale_price, 10.0)
|
||||
|
||||
pt.sync_last_sale_prices(trades[-1].dt, False)
|
||||
pt.sync_last_sale_prices(trades[-1].dt, False, data_portal)
|
||||
|
||||
pp.calculate_performance()
|
||||
|
||||
@@ -1544,18 +1541,17 @@ single short-sale transaction"""
|
||||
{1: trades})
|
||||
|
||||
txn = create_txn(self.asset1, trades[1].dt, 10.0, -100)
|
||||
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
|
||||
pt = perf.PositionTracker(self.env.asset_finder,
|
||||
self.sim_params.data_frequency)
|
||||
pp = perf.PerformancePeriod(
|
||||
1000.0, self.env.asset_finder,
|
||||
self.sim_params.data_frequency,
|
||||
data_portal)
|
||||
self.sim_params.data_frequency)
|
||||
pp.position_tracker = pt
|
||||
|
||||
pt.execute_transaction(txn)
|
||||
pp.handle_execution(txn)
|
||||
|
||||
pt.sync_last_sale_prices(trades_1[-1].dt, False)
|
||||
pt.sync_last_sale_prices(trades_1[-1].dt, False, data_portal)
|
||||
|
||||
pp.calculate_performance()
|
||||
|
||||
@@ -1611,7 +1607,7 @@ single short-sale transaction"""
|
||||
# simulate a rollover to a new period
|
||||
pp.rollover()
|
||||
|
||||
pt.sync_last_sale_prices(trades[-1].dt, False)
|
||||
pt.sync_last_sale_prices(trades[-1].dt, False, data_portal)
|
||||
|
||||
pp.calculate_performance()
|
||||
|
||||
@@ -1665,17 +1661,16 @@ single short-sale transaction"""
|
||||
)
|
||||
|
||||
# now run a performance period encompassing the entire trade sample.
|
||||
ptTotal = perf.PositionTracker(self.env.asset_finder, data_portal,
|
||||
ptTotal = perf.PositionTracker(self.env.asset_finder,
|
||||
self.sim_params.data_frequency)
|
||||
ppTotal = perf.PerformancePeriod(1000.0, self.env.asset_finder,
|
||||
self.sim_params.data_frequency,
|
||||
data_portal)
|
||||
self.sim_params.data_frequency)
|
||||
ppTotal.position_tracker = pt
|
||||
|
||||
ptTotal.execute_transaction(txn)
|
||||
ppTotal.handle_execution(txn)
|
||||
|
||||
ptTotal.sync_last_sale_prices(trades[-1].dt, False)
|
||||
ptTotal.sync_last_sale_prices(trades[-1].dt, False, data_portal)
|
||||
|
||||
ppTotal.calculate_performance()
|
||||
|
||||
@@ -1778,11 +1773,10 @@ cost of sole txn in test"
|
||||
)
|
||||
|
||||
txn = create_txn(self.asset3, trades[1].dt, 10.0, 1)
|
||||
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
|
||||
pt = perf.PositionTracker(self.env.asset_finder,
|
||||
self.sim_params.data_frequency)
|
||||
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
|
||||
self.sim_params.data_frequency,
|
||||
data_portal)
|
||||
self.sim_params.data_frequency)
|
||||
pp.position_tracker = pt
|
||||
|
||||
pt.execute_transaction(txn)
|
||||
@@ -1795,7 +1789,7 @@ cost of sole txn in test"
|
||||
# stocks with a last sale price of 0.
|
||||
self.assertEqual(pp.positions[3].last_sale_price, 10.0)
|
||||
|
||||
pt.sync_last_sale_prices(trades[-1].dt, False)
|
||||
pt.sync_last_sale_prices(trades[-1].dt, False, data_portal)
|
||||
pp.calculate_performance()
|
||||
|
||||
self.assertEqual(
|
||||
@@ -1899,17 +1893,16 @@ single short-sale transaction"""
|
||||
trades_1 = trades[:-2]
|
||||
|
||||
txn = create_txn(self.asset3, trades[0].dt, 10.0, -1)
|
||||
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
|
||||
pt = perf.PositionTracker(self.env.asset_finder,
|
||||
self.sim_params.data_frequency)
|
||||
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
|
||||
self.sim_params.data_frequency,
|
||||
data_portal)
|
||||
self.sim_params.data_frequency)
|
||||
pp.position_tracker = pt
|
||||
|
||||
pt.execute_transaction(txn)
|
||||
pp.handle_execution(txn)
|
||||
|
||||
pt.sync_last_sale_prices(trades[-3].dt, False)
|
||||
pt.sync_last_sale_prices(trades[-3].dt, False, data_portal)
|
||||
pp.calculate_performance()
|
||||
|
||||
self.assertEqual(
|
||||
@@ -1969,7 +1962,7 @@ single short-sale transaction"""
|
||||
# simulate a rollover to a new period
|
||||
pp.rollover()
|
||||
|
||||
pt.sync_last_sale_prices(trades_2[-1].dt, False)
|
||||
pt.sync_last_sale_prices(trades_2[-1].dt, False, data_portal)
|
||||
pp.calculate_performance()
|
||||
|
||||
self.assertEqual(
|
||||
@@ -2027,21 +2020,20 @@ single short-sale transaction"""
|
||||
)
|
||||
|
||||
# now run a performance period encompassing the entire trade sample.
|
||||
ptTotal = perf.PositionTracker(self.env.asset_finder, data_portal,
|
||||
ptTotal = perf.PositionTracker(self.env.asset_finder,
|
||||
self.sim_params.data_frequency)
|
||||
ppTotal = perf.PerformancePeriod(1000.0, self.env.asset_finder,
|
||||
self.sim_params.data_frequency,
|
||||
data_portal)
|
||||
self.sim_params.data_frequency)
|
||||
ppTotal.position_tracker = ptTotal
|
||||
|
||||
for trade in trades_1:
|
||||
ptTotal.sync_last_sale_prices(trade.dt, False)
|
||||
ptTotal.sync_last_sale_prices(trade.dt, False, data_portal)
|
||||
|
||||
ptTotal.execute_transaction(txn)
|
||||
ppTotal.handle_execution(txn)
|
||||
|
||||
for trade in trades_2:
|
||||
ptTotal.sync_last_sale_prices(trade.dt, False)
|
||||
ptTotal.sync_last_sale_prices(trade.dt, False, data_portal)
|
||||
|
||||
ppTotal.calculate_performance()
|
||||
|
||||
@@ -2144,11 +2136,10 @@ trade after cover"""
|
||||
|
||||
short_txn = create_txn(self.asset1, trades[1].dt, 10.0, -100)
|
||||
cover_txn = create_txn(self.asset1, trades[6].dt, 7.0, 100)
|
||||
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
|
||||
pt = perf.PositionTracker(self.env.asset_finder,
|
||||
self.sim_params.data_frequency)
|
||||
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
|
||||
self.sim_params.data_frequency,
|
||||
data_portal)
|
||||
self.sim_params.data_frequency)
|
||||
pp.position_tracker = pt
|
||||
|
||||
pt.execute_transaction(short_txn)
|
||||
@@ -2156,7 +2147,7 @@ trade after cover"""
|
||||
pt.execute_transaction(cover_txn)
|
||||
pp.handle_execution(cover_txn)
|
||||
|
||||
pt.sync_last_sale_prices(trades[-1].dt, False)
|
||||
pt.sync_last_sale_prices(trades[-1].dt, False, data_portal)
|
||||
|
||||
pp.calculate_performance()
|
||||
|
||||
@@ -2231,13 +2222,12 @@ shares in position"
|
||||
self.sim_params,
|
||||
{1: trades})
|
||||
|
||||
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
|
||||
pt = perf.PositionTracker(self.env.asset_finder,
|
||||
self.sim_params.data_frequency)
|
||||
pp = perf.PerformancePeriod(
|
||||
1000.0,
|
||||
self.env.asset_finder,
|
||||
self.sim_params.data_frequency,
|
||||
data_portal,
|
||||
period_open=self.sim_params.period_start,
|
||||
period_close=self.sim_params.trading_days[-1]
|
||||
)
|
||||
@@ -2264,7 +2254,7 @@ shares in position"
|
||||
"should have a cost basis of 11"
|
||||
)
|
||||
|
||||
pt.sync_last_sale_prices(dt, False)
|
||||
pt.sync_last_sale_prices(dt, False, data_portal)
|
||||
|
||||
pp.calculate_performance()
|
||||
|
||||
@@ -2281,7 +2271,7 @@ shares in position"
|
||||
pp.handle_execution(sale_txn)
|
||||
|
||||
dt = down_tick.dt
|
||||
pt.sync_last_sale_prices(dt, False)
|
||||
pt.sync_last_sale_prices(dt, False, data_portal)
|
||||
|
||||
pp.calculate_performance()
|
||||
self.assertEqual(
|
||||
@@ -2299,11 +2289,10 @@ shares in position"
|
||||
|
||||
self.assertEqual(pp.pnl, -800, "this period goes from +400 to -400")
|
||||
|
||||
pt3 = perf.PositionTracker(self.env.asset_finder, data_portal,
|
||||
pt3 = perf.PositionTracker(self.env.asset_finder,
|
||||
self.sim_params.data_frequency)
|
||||
pp3 = perf.PerformancePeriod(1000.0, self.env.asset_finder,
|
||||
self.sim_params.data_frequency,
|
||||
data_portal)
|
||||
self.sim_params.data_frequency)
|
||||
pp3.position_tracker = pt3
|
||||
|
||||
average_cost = 0
|
||||
@@ -2317,7 +2306,7 @@ shares in position"
|
||||
pp3.handle_execution(sale_txn)
|
||||
|
||||
trades.append(down_tick)
|
||||
pt3.sync_last_sale_prices(trades[-1].dt, False)
|
||||
pt3.sync_last_sale_prices(trades[-1].dt, False, data_portal)
|
||||
|
||||
pp3.calculate_performance()
|
||||
self.assertEqual(
|
||||
@@ -2351,18 +2340,11 @@ shares in position"
|
||||
)
|
||||
cost_bases = [10, 10, 0, 8, 9, 9, 13, 13.5]
|
||||
|
||||
trades = factory.create_trade_history(*history_args)
|
||||
transactions = factory.create_txn_history(*history_args)
|
||||
|
||||
data_portal = create_data_portal_from_trade_history(
|
||||
self.env,
|
||||
self.tempdir,
|
||||
self.sim_params,
|
||||
{1: trades})
|
||||
|
||||
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
|
||||
pt = perf.PositionTracker(self.env.asset_finder,
|
||||
self.sim_params.data_frequency)
|
||||
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, data_portal,
|
||||
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
|
||||
self.sim_params.data_frequency)
|
||||
pp.position_tracker = pt
|
||||
|
||||
@@ -2413,22 +2395,8 @@ class TestPositionTracker(unittest.TestCase):
|
||||
sim_params = factory.create_simulation_parameters(
|
||||
num_days=4, env=self.env
|
||||
)
|
||||
trades = factory.create_trade_history(
|
||||
1,
|
||||
[10, 10, 10, 11],
|
||||
[100, 100, 100, 100],
|
||||
oneday,
|
||||
sim_params,
|
||||
env=self.env
|
||||
)
|
||||
|
||||
data_portal = create_data_portal_from_trade_history(
|
||||
self.env,
|
||||
self.tempdir,
|
||||
sim_params,
|
||||
{1: trades})
|
||||
|
||||
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
|
||||
pt = perf.PositionTracker(self.env.asset_finder,
|
||||
sim_params.data_frequency)
|
||||
pos_stats = pt.stats()
|
||||
|
||||
@@ -2450,7 +2418,7 @@ class TestPositionTracker(unittest.TestCase):
|
||||
self.assertNotIsInstance(val, (bool, np.bool_))
|
||||
|
||||
def test_position_values_and_exposures(self):
|
||||
pt = perf.PositionTracker(self.env.asset_finder, None, None)
|
||||
pt = perf.PositionTracker(self.env.asset_finder, None)
|
||||
dt = pd.Timestamp("1984/03/06 3:00PM")
|
||||
pos1 = perf.Position(1, amount=np.float64(10.0),
|
||||
last_sale_date=dt, last_sale_price=10)
|
||||
@@ -2482,7 +2450,7 @@ class TestPositionTracker(unittest.TestCase):
|
||||
self.assertEqual(100 - 200 + 300000 - 400000, pos_stats.net_exposure)
|
||||
|
||||
def test_update_positions(self):
|
||||
pt = perf.PositionTracker(self.env.asset_finder, None, None)
|
||||
pt = perf.PositionTracker(self.env.asset_finder, None)
|
||||
dt = pd.Timestamp("2014/01/01 3:00PM")
|
||||
pos1 = perf.Position(1, amount=np.float64(10.0),
|
||||
last_sale_date=dt, last_sale_price=10)
|
||||
|
||||
+11
-9
@@ -492,7 +492,6 @@ class TradingAlgorithm(object):
|
||||
self.perf_tracker = PerformanceTracker(
|
||||
sim_params=self.sim_params,
|
||||
env=self.trading_environment,
|
||||
data_portal=self.data_portal
|
||||
)
|
||||
|
||||
# Set the dt initially to the period start by forcing it to change.
|
||||
@@ -603,14 +602,17 @@ class TradingAlgorithm(object):
|
||||
|
||||
# Create zipline and loop through simulated_trading.
|
||||
# Each iteration returns a perf dictionary
|
||||
perfs = []
|
||||
for perf in self.get_generator():
|
||||
perfs.append(perf)
|
||||
try:
|
||||
perfs = []
|
||||
for perf in self.get_generator():
|
||||
perfs.append(perf)
|
||||
|
||||
# convert perf dict to pandas dataframe
|
||||
daily_stats = self._create_daily_stats(perfs)
|
||||
# convert perf dict to pandas dataframe
|
||||
daily_stats = self._create_daily_stats(perfs)
|
||||
|
||||
self.analyze(daily_stats)
|
||||
self.analyze(daily_stats)
|
||||
finally:
|
||||
self.data_portal = None
|
||||
|
||||
return daily_stats
|
||||
|
||||
@@ -1057,7 +1059,7 @@ class TradingAlgorithm(object):
|
||||
def updated_portfolio(self):
|
||||
if self.portfolio_needs_update:
|
||||
self.perf_tracker.position_tracker.sync_last_sale_prices(
|
||||
self.datetime, self._in_before_trading_start)
|
||||
self.datetime, self._in_before_trading_start, self.data_portal)
|
||||
self._portfolio = \
|
||||
self.perf_tracker.get_portfolio(self.performance_needs_update)
|
||||
self.portfolio_needs_update = False
|
||||
@@ -1071,7 +1073,7 @@ class TradingAlgorithm(object):
|
||||
def updated_account(self):
|
||||
if self.account_needs_update:
|
||||
self.perf_tracker.position_tracker.sync_last_sale_prices(
|
||||
self.datetime, self._in_before_trading_start)
|
||||
self.datetime, self._in_before_trading_start, self.data_portal)
|
||||
self._account = \
|
||||
self.perf_tracker.get_account(self.performance_needs_update)
|
||||
self.account_needs_update = False
|
||||
|
||||
@@ -133,7 +133,6 @@ class PerformancePeriod(object):
|
||||
starting_cash,
|
||||
asset_finder,
|
||||
data_frequency,
|
||||
data_portal,
|
||||
period_open=None,
|
||||
period_close=None,
|
||||
keep_transactions=True,
|
||||
@@ -144,8 +143,6 @@ class PerformancePeriod(object):
|
||||
self.asset_finder = asset_finder
|
||||
self.data_frequency = data_frequency
|
||||
|
||||
self._data_portal = data_portal
|
||||
|
||||
self.period_open = period_open
|
||||
self.period_close = period_close
|
||||
|
||||
|
||||
@@ -119,14 +119,9 @@ def calc_gross_value(long_value, short_value):
|
||||
|
||||
class PositionTracker(object):
|
||||
|
||||
def __init__(self, asset_finder, data_portal, data_frequency):
|
||||
def __init__(self, asset_finder, data_frequency):
|
||||
self.asset_finder = asset_finder
|
||||
|
||||
# FIXME really want to avoid storing a data portal here,
|
||||
# but the path to get to maybe_create_close_position_transaction
|
||||
# is long and tortuous
|
||||
self._data_portal = data_portal
|
||||
|
||||
# sid => position object
|
||||
self.positions = positiondict()
|
||||
# Arrays for quick calculations of positions value
|
||||
@@ -316,12 +311,12 @@ class PositionTracker(object):
|
||||
|
||||
return net_cash_payment
|
||||
|
||||
def maybe_create_close_position_transaction(self, asset, dt):
|
||||
def maybe_create_close_position_transaction(self, asset, dt, data_portal):
|
||||
if not self.positions.get(asset):
|
||||
return None
|
||||
|
||||
amount = self.positions.get(asset).amount
|
||||
price = self._data_portal.get_spot_value(
|
||||
price = data_portal.get_spot_value(
|
||||
asset, 'price', dt, self.data_frequency)
|
||||
|
||||
# Get the last traded price if price is no longer available
|
||||
@@ -372,8 +367,8 @@ class PositionTracker(object):
|
||||
positions.append(pos.to_dict())
|
||||
return positions
|
||||
|
||||
def sync_last_sale_prices(self, dt, handle_non_market_minutes):
|
||||
data_portal = self._data_portal
|
||||
def sync_last_sale_prices(self, dt, handle_non_market_minutes,
|
||||
data_portal):
|
||||
if not handle_non_market_minutes:
|
||||
for asset, position in iteritems(self.positions):
|
||||
last_sale_price = data_portal.get_spot_value(
|
||||
|
||||
@@ -78,7 +78,7 @@ class PerformanceTracker(object):
|
||||
"""
|
||||
Tracks the performance of the algorithm.
|
||||
"""
|
||||
def __init__(self, sim_params, env, data_portal):
|
||||
def __init__(self, sim_params, env):
|
||||
self.sim_params = sim_params
|
||||
self.env = env
|
||||
|
||||
@@ -101,15 +101,8 @@ class PerformanceTracker(object):
|
||||
|
||||
self.trading_days = all_trading_days[mask]
|
||||
|
||||
self._data_portal = data_portal
|
||||
if data_portal is not None:
|
||||
self._adjustment_reader = data_portal._adjustment_reader
|
||||
else:
|
||||
self._adjustment_reader = None
|
||||
|
||||
self.position_tracker = PositionTracker(
|
||||
asset_finder=env.asset_finder,
|
||||
data_portal=data_portal,
|
||||
data_frequency=self.sim_params.data_frequency)
|
||||
|
||||
if self.emission_rate == 'daily':
|
||||
@@ -132,7 +125,6 @@ class PerformanceTracker(object):
|
||||
# initial cash is your capital base.
|
||||
starting_cash=self.capital_base,
|
||||
data_frequency=self.sim_params.data_frequency,
|
||||
data_portal=data_portal,
|
||||
# the cumulative period will be calculated over the entire test.
|
||||
period_open=self.period_start,
|
||||
period_close=self.period_end,
|
||||
@@ -152,7 +144,6 @@ class PerformanceTracker(object):
|
||||
# initial cash is your capital base.
|
||||
starting_cash=self.capital_base,
|
||||
data_frequency=self.sim_params.data_frequency,
|
||||
data_portal=data_portal,
|
||||
# the daily period will be calculated for the market day
|
||||
period_open=self.market_open,
|
||||
period_close=self.market_close,
|
||||
@@ -264,13 +255,13 @@ class PerformanceTracker(object):
|
||||
self.cumulative_performance.handle_commission(cost)
|
||||
self.todays_performance.handle_commission(cost)
|
||||
|
||||
def process_close_position(self, asset, dt):
|
||||
def process_close_position(self, asset, dt, data_portal):
|
||||
txn = self.position_tracker.\
|
||||
maybe_create_close_position_transaction(asset, dt)
|
||||
maybe_create_close_position_transaction(asset, dt, data_portal)
|
||||
if txn:
|
||||
self.process_transaction(txn)
|
||||
|
||||
def check_upcoming_dividends(self, next_trading_day):
|
||||
def check_upcoming_dividends(self, next_trading_day, adjustment_reader):
|
||||
"""
|
||||
Check if we currently own any stocks with dividends whose ex_date is
|
||||
the next trading day. Track how much we should be payed on those
|
||||
@@ -280,7 +271,7 @@ class PerformanceTracker(object):
|
||||
is the next trading day. Apply all such benefits, then recalculate
|
||||
performance.
|
||||
"""
|
||||
if self._adjustment_reader is None:
|
||||
if adjustment_reader is None:
|
||||
return
|
||||
position_tracker = self.position_tracker
|
||||
held_sids = set(position_tracker.positions)
|
||||
@@ -291,10 +282,10 @@ class PerformanceTracker(object):
|
||||
if held_sids:
|
||||
asset_finder = self.env.asset_finder
|
||||
|
||||
cash_dividends = self._adjustment_reader.\
|
||||
cash_dividends = adjustment_reader.\
|
||||
get_dividends_with_ex_date(held_sids, next_trading_day,
|
||||
asset_finder)
|
||||
stock_dividends = self._adjustment_reader.\
|
||||
stock_dividends = adjustment_reader.\
|
||||
get_stock_dividends_with_ex_date(held_sids, next_trading_day,
|
||||
asset_finder)
|
||||
|
||||
@@ -310,7 +301,7 @@ class PerformanceTracker(object):
|
||||
self.cumulative_performance.handle_dividends_paid(net_cash_payment)
|
||||
self.todays_performance.handle_dividends_paid(net_cash_payment)
|
||||
|
||||
def handle_minute_close(self, dt):
|
||||
def handle_minute_close(self, dt, data_portal):
|
||||
"""
|
||||
Handles the close of the given minute. This includes handling
|
||||
market-close functions if the given minute is the end of the market
|
||||
@@ -327,7 +318,7 @@ class PerformanceTracker(object):
|
||||
A tuple of the minute perf packet and daily perf packet.
|
||||
If the market day has not ended, the daily perf packet is None.
|
||||
"""
|
||||
self.position_tracker.sync_last_sale_prices(dt, False)
|
||||
self.position_tracker.sync_last_sale_prices(dt, False, data_portal)
|
||||
self.update_performance()
|
||||
todays_date = normalize_date(dt)
|
||||
account = self.get_account(False)
|
||||
@@ -346,16 +337,18 @@ class PerformanceTracker(object):
|
||||
# if this is the close, update dividends for the next day.
|
||||
# Return the performance tuple
|
||||
if dt == self.market_close:
|
||||
return minute_packet, self._handle_market_close(todays_date)
|
||||
return minute_packet, self._handle_market_close(
|
||||
todays_date, data_portal._adjustment_reader,
|
||||
)
|
||||
else:
|
||||
return minute_packet, None
|
||||
|
||||
def handle_market_close_daily(self, dt):
|
||||
def handle_market_close_daily(self, dt, data_portal):
|
||||
"""
|
||||
Function called after handle_data when running with daily emission
|
||||
rate.
|
||||
"""
|
||||
self.position_tracker.sync_last_sale_prices(dt, False)
|
||||
self.position_tracker.sync_last_sale_prices(dt, False, data_portal)
|
||||
self.update_performance()
|
||||
completed_date = self.day
|
||||
account = self.get_account(False)
|
||||
@@ -368,11 +361,12 @@ class PerformanceTracker(object):
|
||||
benchmark_value,
|
||||
account.leverage)
|
||||
|
||||
daily_packet = self._handle_market_close(completed_date)
|
||||
|
||||
daily_packet = self._handle_market_close(
|
||||
completed_date, data_portal._adjustment_reader,
|
||||
)
|
||||
return daily_packet
|
||||
|
||||
def _handle_market_close(self, completed_date):
|
||||
def _handle_market_close(self, completed_date, adjustment_reader):
|
||||
|
||||
# increment the day counter before we move markers forward.
|
||||
self.day_count += 1.0
|
||||
@@ -406,7 +400,8 @@ class PerformanceTracker(object):
|
||||
return daily_update
|
||||
|
||||
# Check for any dividends, then return the daily perf packet
|
||||
self.check_upcoming_dividends(next_trading_day=next_trading_day)
|
||||
self.check_upcoming_dividends(next_trading_day=next_trading_day,
|
||||
adjustment_reader=adjustment_reader)
|
||||
return daily_update
|
||||
|
||||
def handle_simulation_end(self):
|
||||
|
||||
@@ -97,20 +97,9 @@ class AlgorithmSimulator(object):
|
||||
Main generator work loop.
|
||||
"""
|
||||
algo = self.algo
|
||||
algo.data_portal = self.data_portal
|
||||
handle_data = algo.event_manager.handle_data
|
||||
current_data = self.current_data
|
||||
|
||||
data_portal = self.data_portal
|
||||
|
||||
# can't cache a pointer to algo.perf_tracker because we're not
|
||||
# guaranteed that the algo doesn't swap out perf trackers during
|
||||
# its lifetime.
|
||||
# likewise, we can't cache a pointer to the blotter.
|
||||
|
||||
algo.perf_tracker.position_tracker.data_portal = data_portal
|
||||
|
||||
def every_bar(dt_to_use):
|
||||
def every_bar(dt_to_use, current_data=self.current_data,
|
||||
handle_data=algo.event_manager.handle_data):
|
||||
# called every tick (minute or day).
|
||||
|
||||
self.simulation_dt = dt_to_use
|
||||
@@ -152,7 +141,8 @@ class AlgorithmSimulator(object):
|
||||
self.algo.account_needs_update = True
|
||||
self.algo.performance_needs_update = True
|
||||
|
||||
def once_a_day(midnight_dt):
|
||||
def once_a_day(midnight_dt, current_data=self.current_data,
|
||||
data_portal=self.data_portal):
|
||||
# Get the positions before updating the date so that prices are
|
||||
# fetched for trading close instead of midnight
|
||||
positions = algo.perf_tracker.position_tracker.positions
|
||||
@@ -183,11 +173,15 @@ class AlgorithmSimulator(object):
|
||||
# call before trading start
|
||||
algo.before_trading_start(current_data)
|
||||
|
||||
def handle_benchmark(date):
|
||||
def handle_benchmark(date, benchmark_source=self.benchmark_source):
|
||||
algo.perf_tracker.all_benchmark_returns[date] = \
|
||||
self.benchmark_source.get_value(date)
|
||||
benchmark_source.get_value(date)
|
||||
|
||||
def on_exit():
|
||||
self.benchmark_source = self.current_data = self.data_portal = None
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.callback(on_exit)
|
||||
stack.enter_context(self.processor)
|
||||
stack.enter_context(ZiplineAPI(self.algo))
|
||||
|
||||
@@ -245,8 +239,9 @@ class AlgorithmSimulator(object):
|
||||
assets_to_clear = \
|
||||
[asset for asset in position_assets if past_auto_close_date(asset)]
|
||||
perf_tracker = algo.perf_tracker
|
||||
data_portal = self.data_portal
|
||||
for asset in assets_to_clear:
|
||||
perf_tracker.process_close_position(asset, dt)
|
||||
perf_tracker.process_close_position(asset, dt, data_portal)
|
||||
|
||||
# Remove open orders for any sids that have reached their
|
||||
# auto_close_date.
|
||||
@@ -257,23 +252,25 @@ class AlgorithmSimulator(object):
|
||||
for asset in assets_to_cancel:
|
||||
blotter.cancel_all_orders_for_asset(asset)
|
||||
|
||||
@staticmethod
|
||||
def _get_daily_message(dt, algo, perf_tracker):
|
||||
def _get_daily_message(self, dt, algo, perf_tracker):
|
||||
"""
|
||||
Get a perf message for the given datetime.
|
||||
"""
|
||||
perf_message = perf_tracker.handle_market_close_daily(dt)
|
||||
perf_message = perf_tracker.handle_market_close_daily(
|
||||
dt, self.data_portal,
|
||||
)
|
||||
perf_message['daily_perf']['recorded_vars'] = algo.recorded_vars
|
||||
return perf_message
|
||||
|
||||
@staticmethod
|
||||
def _get_minute_message(dt, algo, perf_tracker):
|
||||
def _get_minute_message(self, dt, algo, perf_tracker):
|
||||
"""
|
||||
Get a perf message for the given datetime.
|
||||
"""
|
||||
rvars = algo.recorded_vars
|
||||
|
||||
minute_message, daily_message = perf_tracker.handle_minute_close(dt)
|
||||
minute_message, daily_message = perf_tracker.handle_minute_close(
|
||||
dt, self.data_portal,
|
||||
)
|
||||
minute_message['minute_perf']['recorded_vars'] = rvars
|
||||
|
||||
if daily_message:
|
||||
|
||||
+1
-2
@@ -16,9 +16,8 @@ import pandas as pd
|
||||
|
||||
from .utils.enum import enum
|
||||
|
||||
from zipline._protocol import BarData as _BarData
|
||||
from zipline._protocol import BarData # noqa
|
||||
|
||||
BarData = _BarData
|
||||
|
||||
# Datasource type should completely determine the other fields of a
|
||||
# message with its type.
|
||||
|
||||
Reference in New Issue
Block a user