MAINT: Don't store data portal everywhere

Removed lots of data portal references that participated in ref cycles
and prevented deterministic cleanup of dbs.
This commit is contained in:
Richard Frank
2016-04-08 07:28:30 -04:00
parent 8b610a2ab7
commit 70befd490b
8 changed files with 107 additions and 154 deletions
+1 -1
View File
@@ -286,7 +286,7 @@ class FinanceTestCase(TestCase):
else:
alternator = 1
tracker = PerformanceTracker(sim_params, self.env, data_portal)
tracker = PerformanceTracker(sim_params, self.env)
# replicate what tradesim does by going through every minute or day
# of the simulation and processing open orders each time
+49 -81
View File
@@ -50,7 +50,7 @@ from zipline.utils.serialization_utils import (
loads_with_persistent_ids, dumps_with_persistent_ids
)
from zipline.testing.core import create_data_portal_from_trade_history, \
create_empty_splits_mergers_frame, FakeDataPortal
create_empty_splits_mergers_frame
logger = logging.getLogger('Test Perf Tracking')
@@ -167,7 +167,7 @@ def calculate_results(sim_params,
splits = splits or {}
commissions = commissions or {}
perf_tracker = perf.PerformanceTracker(sim_params, env, data_portal)
perf_tracker = perf.PerformanceTracker(sim_params, env)
results = []
@@ -189,8 +189,10 @@ def calculate_results(sim_params,
except KeyError:
pass
msg = perf_tracker.handle_market_close_daily(date)
perf_tracker.position_tracker.sync_last_sale_prices(date, False)
msg = perf_tracker.handle_market_close_daily(date, data_portal)
perf_tracker.position_tracker.sync_last_sale_prices(
date, False, data_portal,
)
msg['account'] = perf_tracker.get_account(True)
results.append(copy.deepcopy(msg))
return results
@@ -265,9 +267,7 @@ class TestSplitPerformance(unittest.TestCase):
def test_multiple_splits(self):
# if multiple positions all have splits at the same time, verify that
# the total leftover cash is correct
perf_tracker = perf.PerformanceTracker(
self.sim_params, self.env, FakeDataPortal()
)
perf_tracker = perf.PerformanceTracker(self.sim_params, self.env)
asset1 = self.env.asset_finder.retrieve_asset(1)
asset2 = self.env.asset_finder.retrieve_asset(2)
@@ -1240,11 +1240,10 @@ class TestPositionPerformance(unittest.TestCase):
txn1 = create_txn(self.asset1, trades_1[0].dt, 10.0, 100)
txn2 = create_txn(self.asset2, trades_1[0].dt, 10.0, -100)
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
pp.position_tracker = pt
pt.execute_transaction(txn1)
pp.handle_execution(txn1)
@@ -1252,7 +1251,7 @@ class TestPositionPerformance(unittest.TestCase):
pp.handle_execution(txn2)
dt = trades_1[-2].dt
pt.sync_last_sale_prices(dt, False)
pt.sync_last_sale_prices(dt, False, data_portal)
pp.calculate_performance()
@@ -1280,7 +1279,7 @@ class TestPositionPerformance(unittest.TestCase):
net_liquidation=1000.0)
dt = trades_1[-1].dt
pt.sync_last_sale_prices(dt, False)
pt.sync_last_sale_prices(dt, False, data_portal)
pp.calculate_performance()
@@ -1333,11 +1332,10 @@ class TestPositionPerformance(unittest.TestCase):
self.sim_params,
{1: trades})
txn = create_txn(self.asset1, trades[1].dt, 10.0, 1000)
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
pp.position_tracker = pt
pt.execute_transaction(txn)
@@ -1355,7 +1353,7 @@ class TestPositionPerformance(unittest.TestCase):
shorts_count=0)
# Validate that the account attributes were updated.
pt.sync_last_sale_prices(trades[-2].dt, False)
pt.sync_last_sale_prices(trades[-2].dt, False, data_portal)
# Validate that the account attributes were updated.
account = pp.as_account()
@@ -1373,7 +1371,7 @@ class TestPositionPerformance(unittest.TestCase):
net_liquidation=1000.0)
# now simulate a price jump to $11
pt.sync_last_sale_prices(trades[-1].dt, False)
pt.sync_last_sale_prices(trades[-1].dt, False, data_portal)
pp.calculate_performance()
@@ -1425,11 +1423,10 @@ class TestPositionPerformance(unittest.TestCase):
self.sim_params,
{1: trades})
txn = create_txn(self.asset1, trades[1].dt, 10.0, 100)
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal,
period_open=self.sim_params.period_start,
period_close=self.sim_params.period_end)
pp.position_tracker = pt
@@ -1444,7 +1441,7 @@ class TestPositionPerformance(unittest.TestCase):
# stocks with a last sale price of 0.
self.assertEqual(pp.positions[1].last_sale_price, 10.0)
pt.sync_last_sale_prices(trades[-1].dt, False)
pt.sync_last_sale_prices(trades[-1].dt, False, data_portal)
pp.calculate_performance()
@@ -1544,18 +1541,17 @@ single short-sale transaction"""
{1: trades})
txn = create_txn(self.asset1, trades[1].dt, 10.0, -100)
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(
1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
pp.position_tracker = pt
pt.execute_transaction(txn)
pp.handle_execution(txn)
pt.sync_last_sale_prices(trades_1[-1].dt, False)
pt.sync_last_sale_prices(trades_1[-1].dt, False, data_portal)
pp.calculate_performance()
@@ -1611,7 +1607,7 @@ single short-sale transaction"""
# simulate a rollover to a new period
pp.rollover()
pt.sync_last_sale_prices(trades[-1].dt, False)
pt.sync_last_sale_prices(trades[-1].dt, False, data_portal)
pp.calculate_performance()
@@ -1665,17 +1661,16 @@ single short-sale transaction"""
)
# now run a performance period encompassing the entire trade sample.
ptTotal = perf.PositionTracker(self.env.asset_finder, data_portal,
ptTotal = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
ppTotal = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
ppTotal.position_tracker = pt
ptTotal.execute_transaction(txn)
ppTotal.handle_execution(txn)
ptTotal.sync_last_sale_prices(trades[-1].dt, False)
ptTotal.sync_last_sale_prices(trades[-1].dt, False, data_portal)
ppTotal.calculate_performance()
@@ -1778,11 +1773,10 @@ cost of sole txn in test"
)
txn = create_txn(self.asset3, trades[1].dt, 10.0, 1)
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
pp.position_tracker = pt
pt.execute_transaction(txn)
@@ -1795,7 +1789,7 @@ cost of sole txn in test"
# stocks with a last sale price of 0.
self.assertEqual(pp.positions[3].last_sale_price, 10.0)
pt.sync_last_sale_prices(trades[-1].dt, False)
pt.sync_last_sale_prices(trades[-1].dt, False, data_portal)
pp.calculate_performance()
self.assertEqual(
@@ -1899,17 +1893,16 @@ single short-sale transaction"""
trades_1 = trades[:-2]
txn = create_txn(self.asset3, trades[0].dt, 10.0, -1)
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
pp.position_tracker = pt
pt.execute_transaction(txn)
pp.handle_execution(txn)
pt.sync_last_sale_prices(trades[-3].dt, False)
pt.sync_last_sale_prices(trades[-3].dt, False, data_portal)
pp.calculate_performance()
self.assertEqual(
@@ -1969,7 +1962,7 @@ single short-sale transaction"""
# simulate a rollover to a new period
pp.rollover()
pt.sync_last_sale_prices(trades_2[-1].dt, False)
pt.sync_last_sale_prices(trades_2[-1].dt, False, data_portal)
pp.calculate_performance()
self.assertEqual(
@@ -2027,21 +2020,20 @@ single short-sale transaction"""
)
# now run a performance period encompassing the entire trade sample.
ptTotal = perf.PositionTracker(self.env.asset_finder, data_portal,
ptTotal = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
ppTotal = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
ppTotal.position_tracker = ptTotal
for trade in trades_1:
ptTotal.sync_last_sale_prices(trade.dt, False)
ptTotal.sync_last_sale_prices(trade.dt, False, data_portal)
ptTotal.execute_transaction(txn)
ppTotal.handle_execution(txn)
for trade in trades_2:
ptTotal.sync_last_sale_prices(trade.dt, False)
ptTotal.sync_last_sale_prices(trade.dt, False, data_portal)
ppTotal.calculate_performance()
@@ -2144,11 +2136,10 @@ trade after cover"""
short_txn = create_txn(self.asset1, trades[1].dt, 10.0, -100)
cover_txn = create_txn(self.asset1, trades[6].dt, 7.0, 100)
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
pp.position_tracker = pt
pt.execute_transaction(short_txn)
@@ -2156,7 +2147,7 @@ trade after cover"""
pt.execute_transaction(cover_txn)
pp.handle_execution(cover_txn)
pt.sync_last_sale_prices(trades[-1].dt, False)
pt.sync_last_sale_prices(trades[-1].dt, False, data_portal)
pp.calculate_performance()
@@ -2231,13 +2222,12 @@ shares in position"
self.sim_params,
{1: trades})
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(
1000.0,
self.env.asset_finder,
self.sim_params.data_frequency,
data_portal,
period_open=self.sim_params.period_start,
period_close=self.sim_params.trading_days[-1]
)
@@ -2264,7 +2254,7 @@ shares in position"
"should have a cost basis of 11"
)
pt.sync_last_sale_prices(dt, False)
pt.sync_last_sale_prices(dt, False, data_portal)
pp.calculate_performance()
@@ -2281,7 +2271,7 @@ shares in position"
pp.handle_execution(sale_txn)
dt = down_tick.dt
pt.sync_last_sale_prices(dt, False)
pt.sync_last_sale_prices(dt, False, data_portal)
pp.calculate_performance()
self.assertEqual(
@@ -2299,11 +2289,10 @@ shares in position"
self.assertEqual(pp.pnl, -800, "this period goes from +400 to -400")
pt3 = perf.PositionTracker(self.env.asset_finder, data_portal,
pt3 = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp3 = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency,
data_portal)
self.sim_params.data_frequency)
pp3.position_tracker = pt3
average_cost = 0
@@ -2317,7 +2306,7 @@ shares in position"
pp3.handle_execution(sale_txn)
trades.append(down_tick)
pt3.sync_last_sale_prices(trades[-1].dt, False)
pt3.sync_last_sale_prices(trades[-1].dt, False, data_portal)
pp3.calculate_performance()
self.assertEqual(
@@ -2351,18 +2340,11 @@ shares in position"
)
cost_bases = [10, 10, 0, 8, 9, 9, 13, 13.5]
trades = factory.create_trade_history(*history_args)
transactions = factory.create_txn_history(*history_args)
data_portal = create_data_portal_from_trade_history(
self.env,
self.tempdir,
self.sim_params,
{1: trades})
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
self.sim_params.data_frequency)
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, data_portal,
pp = perf.PerformancePeriod(1000.0, self.env.asset_finder,
self.sim_params.data_frequency)
pp.position_tracker = pt
@@ -2413,22 +2395,8 @@ class TestPositionTracker(unittest.TestCase):
sim_params = factory.create_simulation_parameters(
num_days=4, env=self.env
)
trades = factory.create_trade_history(
1,
[10, 10, 10, 11],
[100, 100, 100, 100],
oneday,
sim_params,
env=self.env
)
data_portal = create_data_portal_from_trade_history(
self.env,
self.tempdir,
sim_params,
{1: trades})
pt = perf.PositionTracker(self.env.asset_finder, data_portal,
pt = perf.PositionTracker(self.env.asset_finder,
sim_params.data_frequency)
pos_stats = pt.stats()
@@ -2450,7 +2418,7 @@ class TestPositionTracker(unittest.TestCase):
self.assertNotIsInstance(val, (bool, np.bool_))
def test_position_values_and_exposures(self):
pt = perf.PositionTracker(self.env.asset_finder, None, None)
pt = perf.PositionTracker(self.env.asset_finder, None)
dt = pd.Timestamp("1984/03/06 3:00PM")
pos1 = perf.Position(1, amount=np.float64(10.0),
last_sale_date=dt, last_sale_price=10)
@@ -2482,7 +2450,7 @@ class TestPositionTracker(unittest.TestCase):
self.assertEqual(100 - 200 + 300000 - 400000, pos_stats.net_exposure)
def test_update_positions(self):
pt = perf.PositionTracker(self.env.asset_finder, None, None)
pt = perf.PositionTracker(self.env.asset_finder, None)
dt = pd.Timestamp("2014/01/01 3:00PM")
pos1 = perf.Position(1, amount=np.float64(10.0),
last_sale_date=dt, last_sale_price=10)
+11 -9
View File
@@ -492,7 +492,6 @@ class TradingAlgorithm(object):
self.perf_tracker = PerformanceTracker(
sim_params=self.sim_params,
env=self.trading_environment,
data_portal=self.data_portal
)
# Set the dt initially to the period start by forcing it to change.
@@ -603,14 +602,17 @@ class TradingAlgorithm(object):
# Create zipline and loop through simulated_trading.
# Each iteration returns a perf dictionary
perfs = []
for perf in self.get_generator():
perfs.append(perf)
try:
perfs = []
for perf in self.get_generator():
perfs.append(perf)
# convert perf dict to pandas dataframe
daily_stats = self._create_daily_stats(perfs)
# convert perf dict to pandas dataframe
daily_stats = self._create_daily_stats(perfs)
self.analyze(daily_stats)
self.analyze(daily_stats)
finally:
self.data_portal = None
return daily_stats
@@ -1057,7 +1059,7 @@ class TradingAlgorithm(object):
def updated_portfolio(self):
if self.portfolio_needs_update:
self.perf_tracker.position_tracker.sync_last_sale_prices(
self.datetime, self._in_before_trading_start)
self.datetime, self._in_before_trading_start, self.data_portal)
self._portfolio = \
self.perf_tracker.get_portfolio(self.performance_needs_update)
self.portfolio_needs_update = False
@@ -1071,7 +1073,7 @@ class TradingAlgorithm(object):
def updated_account(self):
if self.account_needs_update:
self.perf_tracker.position_tracker.sync_last_sale_prices(
self.datetime, self._in_before_trading_start)
self.datetime, self._in_before_trading_start, self.data_portal)
self._account = \
self.perf_tracker.get_account(self.performance_needs_update)
self.account_needs_update = False
-3
View File
@@ -133,7 +133,6 @@ class PerformancePeriod(object):
starting_cash,
asset_finder,
data_frequency,
data_portal,
period_open=None,
period_close=None,
keep_transactions=True,
@@ -144,8 +143,6 @@ class PerformancePeriod(object):
self.asset_finder = asset_finder
self.data_frequency = data_frequency
self._data_portal = data_portal
self.period_open = period_open
self.period_close = period_close
@@ -119,14 +119,9 @@ def calc_gross_value(long_value, short_value):
class PositionTracker(object):
def __init__(self, asset_finder, data_portal, data_frequency):
def __init__(self, asset_finder, data_frequency):
self.asset_finder = asset_finder
# FIXME really want to avoid storing a data portal here,
# but the path to get to maybe_create_close_position_transaction
# is long and tortuous
self._data_portal = data_portal
# sid => position object
self.positions = positiondict()
# Arrays for quick calculations of positions value
@@ -316,12 +311,12 @@ class PositionTracker(object):
return net_cash_payment
def maybe_create_close_position_transaction(self, asset, dt):
def maybe_create_close_position_transaction(self, asset, dt, data_portal):
if not self.positions.get(asset):
return None
amount = self.positions.get(asset).amount
price = self._data_portal.get_spot_value(
price = data_portal.get_spot_value(
asset, 'price', dt, self.data_frequency)
# Get the last traded price if price is no longer available
@@ -372,8 +367,8 @@ class PositionTracker(object):
positions.append(pos.to_dict())
return positions
def sync_last_sale_prices(self, dt, handle_non_market_minutes):
data_portal = self._data_portal
def sync_last_sale_prices(self, dt, handle_non_market_minutes,
data_portal):
if not handle_non_market_minutes:
for asset, position in iteritems(self.positions):
last_sale_price = data_portal.get_spot_value(
+20 -25
View File
@@ -78,7 +78,7 @@ class PerformanceTracker(object):
"""
Tracks the performance of the algorithm.
"""
def __init__(self, sim_params, env, data_portal):
def __init__(self, sim_params, env):
self.sim_params = sim_params
self.env = env
@@ -101,15 +101,8 @@ class PerformanceTracker(object):
self.trading_days = all_trading_days[mask]
self._data_portal = data_portal
if data_portal is not None:
self._adjustment_reader = data_portal._adjustment_reader
else:
self._adjustment_reader = None
self.position_tracker = PositionTracker(
asset_finder=env.asset_finder,
data_portal=data_portal,
data_frequency=self.sim_params.data_frequency)
if self.emission_rate == 'daily':
@@ -132,7 +125,6 @@ class PerformanceTracker(object):
# initial cash is your capital base.
starting_cash=self.capital_base,
data_frequency=self.sim_params.data_frequency,
data_portal=data_portal,
# the cumulative period will be calculated over the entire test.
period_open=self.period_start,
period_close=self.period_end,
@@ -152,7 +144,6 @@ class PerformanceTracker(object):
# initial cash is your capital base.
starting_cash=self.capital_base,
data_frequency=self.sim_params.data_frequency,
data_portal=data_portal,
# the daily period will be calculated for the market day
period_open=self.market_open,
period_close=self.market_close,
@@ -264,13 +255,13 @@ class PerformanceTracker(object):
self.cumulative_performance.handle_commission(cost)
self.todays_performance.handle_commission(cost)
def process_close_position(self, asset, dt):
def process_close_position(self, asset, dt, data_portal):
txn = self.position_tracker.\
maybe_create_close_position_transaction(asset, dt)
maybe_create_close_position_transaction(asset, dt, data_portal)
if txn:
self.process_transaction(txn)
def check_upcoming_dividends(self, next_trading_day):
def check_upcoming_dividends(self, next_trading_day, adjustment_reader):
"""
Check if we currently own any stocks with dividends whose ex_date is
the next trading day. Track how much we should be payed on those
@@ -280,7 +271,7 @@ class PerformanceTracker(object):
is the next trading day. Apply all such benefits, then recalculate
performance.
"""
if self._adjustment_reader is None:
if adjustment_reader is None:
return
position_tracker = self.position_tracker
held_sids = set(position_tracker.positions)
@@ -291,10 +282,10 @@ class PerformanceTracker(object):
if held_sids:
asset_finder = self.env.asset_finder
cash_dividends = self._adjustment_reader.\
cash_dividends = adjustment_reader.\
get_dividends_with_ex_date(held_sids, next_trading_day,
asset_finder)
stock_dividends = self._adjustment_reader.\
stock_dividends = adjustment_reader.\
get_stock_dividends_with_ex_date(held_sids, next_trading_day,
asset_finder)
@@ -310,7 +301,7 @@ class PerformanceTracker(object):
self.cumulative_performance.handle_dividends_paid(net_cash_payment)
self.todays_performance.handle_dividends_paid(net_cash_payment)
def handle_minute_close(self, dt):
def handle_minute_close(self, dt, data_portal):
"""
Handles the close of the given minute. This includes handling
market-close functions if the given minute is the end of the market
@@ -327,7 +318,7 @@ class PerformanceTracker(object):
A tuple of the minute perf packet and daily perf packet.
If the market day has not ended, the daily perf packet is None.
"""
self.position_tracker.sync_last_sale_prices(dt, False)
self.position_tracker.sync_last_sale_prices(dt, False, data_portal)
self.update_performance()
todays_date = normalize_date(dt)
account = self.get_account(False)
@@ -346,16 +337,18 @@ class PerformanceTracker(object):
# if this is the close, update dividends for the next day.
# Return the performance tuple
if dt == self.market_close:
return minute_packet, self._handle_market_close(todays_date)
return minute_packet, self._handle_market_close(
todays_date, data_portal._adjustment_reader,
)
else:
return minute_packet, None
def handle_market_close_daily(self, dt):
def handle_market_close_daily(self, dt, data_portal):
"""
Function called after handle_data when running with daily emission
rate.
"""
self.position_tracker.sync_last_sale_prices(dt, False)
self.position_tracker.sync_last_sale_prices(dt, False, data_portal)
self.update_performance()
completed_date = self.day
account = self.get_account(False)
@@ -368,11 +361,12 @@ class PerformanceTracker(object):
benchmark_value,
account.leverage)
daily_packet = self._handle_market_close(completed_date)
daily_packet = self._handle_market_close(
completed_date, data_portal._adjustment_reader,
)
return daily_packet
def _handle_market_close(self, completed_date):
def _handle_market_close(self, completed_date, adjustment_reader):
# increment the day counter before we move markers forward.
self.day_count += 1.0
@@ -406,7 +400,8 @@ class PerformanceTracker(object):
return daily_update
# Check for any dividends, then return the daily perf packet
self.check_upcoming_dividends(next_trading_day=next_trading_day)
self.check_upcoming_dividends(next_trading_day=next_trading_day,
adjustment_reader=adjustment_reader)
return daily_update
def handle_simulation_end(self):
+20 -23
View File
@@ -97,20 +97,9 @@ class AlgorithmSimulator(object):
Main generator work loop.
"""
algo = self.algo
algo.data_portal = self.data_portal
handle_data = algo.event_manager.handle_data
current_data = self.current_data
data_portal = self.data_portal
# can't cache a pointer to algo.perf_tracker because we're not
# guaranteed that the algo doesn't swap out perf trackers during
# its lifetime.
# likewise, we can't cache a pointer to the blotter.
algo.perf_tracker.position_tracker.data_portal = data_portal
def every_bar(dt_to_use):
def every_bar(dt_to_use, current_data=self.current_data,
handle_data=algo.event_manager.handle_data):
# called every tick (minute or day).
self.simulation_dt = dt_to_use
@@ -152,7 +141,8 @@ class AlgorithmSimulator(object):
self.algo.account_needs_update = True
self.algo.performance_needs_update = True
def once_a_day(midnight_dt):
def once_a_day(midnight_dt, current_data=self.current_data,
data_portal=self.data_portal):
# Get the positions before updating the date so that prices are
# fetched for trading close instead of midnight
positions = algo.perf_tracker.position_tracker.positions
@@ -183,11 +173,15 @@ class AlgorithmSimulator(object):
# call before trading start
algo.before_trading_start(current_data)
def handle_benchmark(date):
def handle_benchmark(date, benchmark_source=self.benchmark_source):
algo.perf_tracker.all_benchmark_returns[date] = \
self.benchmark_source.get_value(date)
benchmark_source.get_value(date)
def on_exit():
self.benchmark_source = self.current_data = self.data_portal = None
with ExitStack() as stack:
stack.callback(on_exit)
stack.enter_context(self.processor)
stack.enter_context(ZiplineAPI(self.algo))
@@ -245,8 +239,9 @@ class AlgorithmSimulator(object):
assets_to_clear = \
[asset for asset in position_assets if past_auto_close_date(asset)]
perf_tracker = algo.perf_tracker
data_portal = self.data_portal
for asset in assets_to_clear:
perf_tracker.process_close_position(asset, dt)
perf_tracker.process_close_position(asset, dt, data_portal)
# Remove open orders for any sids that have reached their
# auto_close_date.
@@ -257,23 +252,25 @@ class AlgorithmSimulator(object):
for asset in assets_to_cancel:
blotter.cancel_all_orders_for_asset(asset)
@staticmethod
def _get_daily_message(dt, algo, perf_tracker):
def _get_daily_message(self, dt, algo, perf_tracker):
"""
Get a perf message for the given datetime.
"""
perf_message = perf_tracker.handle_market_close_daily(dt)
perf_message = perf_tracker.handle_market_close_daily(
dt, self.data_portal,
)
perf_message['daily_perf']['recorded_vars'] = algo.recorded_vars
return perf_message
@staticmethod
def _get_minute_message(dt, algo, perf_tracker):
def _get_minute_message(self, dt, algo, perf_tracker):
"""
Get a perf message for the given datetime.
"""
rvars = algo.recorded_vars
minute_message, daily_message = perf_tracker.handle_minute_close(dt)
minute_message, daily_message = perf_tracker.handle_minute_close(
dt, self.data_portal,
)
minute_message['minute_perf']['recorded_vars'] = rvars
if daily_message:
+1 -2
View File
@@ -16,9 +16,8 @@ import pandas as pd
from .utils.enum import enum
from zipline._protocol import BarData as _BarData
from zipline._protocol import BarData # noqa
BarData = _BarData
# Datasource type should completely determine the other fields of a
# message with its type.