diff --git a/tests/test_finance.py b/tests/test_finance.py index 142272a7..5ba73a96 100644 --- a/tests/test_finance.py +++ b/tests/test_finance.py @@ -286,7 +286,7 @@ class FinanceTestCase(TestCase): else: alternator = 1 - tracker = PerformanceTracker(sim_params, self.env, data_portal) + tracker = PerformanceTracker(sim_params, self.env) # replicate what tradesim does by going through every minute or day # of the simulation and processing open orders each time diff --git a/tests/test_perf_tracking.py b/tests/test_perf_tracking.py index 11b3886e..4174ea52 100644 --- a/tests/test_perf_tracking.py +++ b/tests/test_perf_tracking.py @@ -50,7 +50,7 @@ from zipline.utils.serialization_utils import ( loads_with_persistent_ids, dumps_with_persistent_ids ) from zipline.testing.core import create_data_portal_from_trade_history, \ - create_empty_splits_mergers_frame, FakeDataPortal + create_empty_splits_mergers_frame logger = logging.getLogger('Test Perf Tracking') @@ -167,7 +167,7 @@ def calculate_results(sim_params, splits = splits or {} commissions = commissions or {} - perf_tracker = perf.PerformanceTracker(sim_params, env, data_portal) + perf_tracker = perf.PerformanceTracker(sim_params, env) results = [] @@ -189,8 +189,10 @@ def calculate_results(sim_params, except KeyError: pass - msg = perf_tracker.handle_market_close_daily(date) - perf_tracker.position_tracker.sync_last_sale_prices(date, False) + msg = perf_tracker.handle_market_close_daily(date, data_portal) + perf_tracker.position_tracker.sync_last_sale_prices( + date, False, data_portal, + ) msg['account'] = perf_tracker.get_account(True) results.append(copy.deepcopy(msg)) return results @@ -265,9 +267,7 @@ class TestSplitPerformance(unittest.TestCase): def test_multiple_splits(self): # if multiple positions all have splits at the same time, verify that # the total leftover cash is correct - perf_tracker = perf.PerformanceTracker( - self.sim_params, self.env, FakeDataPortal() - ) + perf_tracker = perf.PerformanceTracker(self.sim_params, self.env) asset1 = self.env.asset_finder.retrieve_asset(1) asset2 = self.env.asset_finder.retrieve_asset(2) @@ -1240,11 +1240,10 @@ class TestPositionPerformance(unittest.TestCase): txn1 = create_txn(self.asset1, trades_1[0].dt, 10.0, 100) txn2 = create_txn(self.asset2, trades_1[0].dt, 10.0, -100) - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) pp.position_tracker = pt pt.execute_transaction(txn1) pp.handle_execution(txn1) @@ -1252,7 +1251,7 @@ class TestPositionPerformance(unittest.TestCase): pp.handle_execution(txn2) dt = trades_1[-2].dt - pt.sync_last_sale_prices(dt, False) + pt.sync_last_sale_prices(dt, False, data_portal) pp.calculate_performance() @@ -1280,7 +1279,7 @@ class TestPositionPerformance(unittest.TestCase): net_liquidation=1000.0) dt = trades_1[-1].dt - pt.sync_last_sale_prices(dt, False) + pt.sync_last_sale_prices(dt, False, data_portal) pp.calculate_performance() @@ -1333,11 +1332,10 @@ class TestPositionPerformance(unittest.TestCase): self.sim_params, {1: trades}) txn = create_txn(self.asset1, trades[1].dt, 10.0, 1000) - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) pp.position_tracker = pt pt.execute_transaction(txn) @@ -1355,7 +1353,7 @@ class TestPositionPerformance(unittest.TestCase): shorts_count=0) # Validate that the account attributes were updated. - pt.sync_last_sale_prices(trades[-2].dt, False) + pt.sync_last_sale_prices(trades[-2].dt, False, data_portal) # Validate that the account attributes were updated. account = pp.as_account() @@ -1373,7 +1371,7 @@ class TestPositionPerformance(unittest.TestCase): net_liquidation=1000.0) # now simulate a price jump to $11 - pt.sync_last_sale_prices(trades[-1].dt, False) + pt.sync_last_sale_prices(trades[-1].dt, False, data_portal) pp.calculate_performance() @@ -1425,11 +1423,10 @@ class TestPositionPerformance(unittest.TestCase): self.sim_params, {1: trades}) txn = create_txn(self.asset1, trades[1].dt, 10.0, 100) - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, self.sim_params.data_frequency, - data_portal, period_open=self.sim_params.period_start, period_close=self.sim_params.period_end) pp.position_tracker = pt @@ -1444,7 +1441,7 @@ class TestPositionPerformance(unittest.TestCase): # stocks with a last sale price of 0. self.assertEqual(pp.positions[1].last_sale_price, 10.0) - pt.sync_last_sale_prices(trades[-1].dt, False) + pt.sync_last_sale_prices(trades[-1].dt, False, data_portal) pp.calculate_performance() @@ -1544,18 +1541,17 @@ single short-sale transaction""" {1: trades}) txn = create_txn(self.asset1, trades[1].dt, 10.0, -100) - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp = perf.PerformancePeriod( 1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) pp.position_tracker = pt pt.execute_transaction(txn) pp.handle_execution(txn) - pt.sync_last_sale_prices(trades_1[-1].dt, False) + pt.sync_last_sale_prices(trades_1[-1].dt, False, data_portal) pp.calculate_performance() @@ -1611,7 +1607,7 @@ single short-sale transaction""" # simulate a rollover to a new period pp.rollover() - pt.sync_last_sale_prices(trades[-1].dt, False) + pt.sync_last_sale_prices(trades[-1].dt, False, data_portal) pp.calculate_performance() @@ -1665,17 +1661,16 @@ single short-sale transaction""" ) # now run a performance period encompassing the entire trade sample. - ptTotal = perf.PositionTracker(self.env.asset_finder, data_portal, + ptTotal = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) ppTotal = perf.PerformancePeriod(1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) ppTotal.position_tracker = pt ptTotal.execute_transaction(txn) ppTotal.handle_execution(txn) - ptTotal.sync_last_sale_prices(trades[-1].dt, False) + ptTotal.sync_last_sale_prices(trades[-1].dt, False, data_portal) ppTotal.calculate_performance() @@ -1778,11 +1773,10 @@ cost of sole txn in test" ) txn = create_txn(self.asset3, trades[1].dt, 10.0, 1) - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) pp.position_tracker = pt pt.execute_transaction(txn) @@ -1795,7 +1789,7 @@ cost of sole txn in test" # stocks with a last sale price of 0. self.assertEqual(pp.positions[3].last_sale_price, 10.0) - pt.sync_last_sale_prices(trades[-1].dt, False) + pt.sync_last_sale_prices(trades[-1].dt, False, data_portal) pp.calculate_performance() self.assertEqual( @@ -1899,17 +1893,16 @@ single short-sale transaction""" trades_1 = trades[:-2] txn = create_txn(self.asset3, trades[0].dt, 10.0, -1) - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) pp.position_tracker = pt pt.execute_transaction(txn) pp.handle_execution(txn) - pt.sync_last_sale_prices(trades[-3].dt, False) + pt.sync_last_sale_prices(trades[-3].dt, False, data_portal) pp.calculate_performance() self.assertEqual( @@ -1969,7 +1962,7 @@ single short-sale transaction""" # simulate a rollover to a new period pp.rollover() - pt.sync_last_sale_prices(trades_2[-1].dt, False) + pt.sync_last_sale_prices(trades_2[-1].dt, False, data_portal) pp.calculate_performance() self.assertEqual( @@ -2027,21 +2020,20 @@ single short-sale transaction""" ) # now run a performance period encompassing the entire trade sample. - ptTotal = perf.PositionTracker(self.env.asset_finder, data_portal, + ptTotal = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) ppTotal = perf.PerformancePeriod(1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) ppTotal.position_tracker = ptTotal for trade in trades_1: - ptTotal.sync_last_sale_prices(trade.dt, False) + ptTotal.sync_last_sale_prices(trade.dt, False, data_portal) ptTotal.execute_transaction(txn) ppTotal.handle_execution(txn) for trade in trades_2: - ptTotal.sync_last_sale_prices(trade.dt, False) + ptTotal.sync_last_sale_prices(trade.dt, False, data_portal) ppTotal.calculate_performance() @@ -2144,11 +2136,10 @@ trade after cover""" short_txn = create_txn(self.asset1, trades[1].dt, 10.0, -100) cover_txn = create_txn(self.asset1, trades[6].dt, 7.0, 100) - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) pp.position_tracker = pt pt.execute_transaction(short_txn) @@ -2156,7 +2147,7 @@ trade after cover""" pt.execute_transaction(cover_txn) pp.handle_execution(cover_txn) - pt.sync_last_sale_prices(trades[-1].dt, False) + pt.sync_last_sale_prices(trades[-1].dt, False, data_portal) pp.calculate_performance() @@ -2231,13 +2222,12 @@ shares in position" self.sim_params, {1: trades}) - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp = perf.PerformancePeriod( 1000.0, self.env.asset_finder, self.sim_params.data_frequency, - data_portal, period_open=self.sim_params.period_start, period_close=self.sim_params.trading_days[-1] ) @@ -2264,7 +2254,7 @@ shares in position" "should have a cost basis of 11" ) - pt.sync_last_sale_prices(dt, False) + pt.sync_last_sale_prices(dt, False, data_portal) pp.calculate_performance() @@ -2281,7 +2271,7 @@ shares in position" pp.handle_execution(sale_txn) dt = down_tick.dt - pt.sync_last_sale_prices(dt, False) + pt.sync_last_sale_prices(dt, False, data_portal) pp.calculate_performance() self.assertEqual( @@ -2299,11 +2289,10 @@ shares in position" self.assertEqual(pp.pnl, -800, "this period goes from +400 to -400") - pt3 = perf.PositionTracker(self.env.asset_finder, data_portal, + pt3 = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) pp3 = perf.PerformancePeriod(1000.0, self.env.asset_finder, - self.sim_params.data_frequency, - data_portal) + self.sim_params.data_frequency) pp3.position_tracker = pt3 average_cost = 0 @@ -2317,7 +2306,7 @@ shares in position" pp3.handle_execution(sale_txn) trades.append(down_tick) - pt3.sync_last_sale_prices(trades[-1].dt, False) + pt3.sync_last_sale_prices(trades[-1].dt, False, data_portal) pp3.calculate_performance() self.assertEqual( @@ -2351,18 +2340,11 @@ shares in position" ) cost_bases = [10, 10, 0, 8, 9, 9, 13, 13.5] - trades = factory.create_trade_history(*history_args) transactions = factory.create_txn_history(*history_args) - data_portal = create_data_portal_from_trade_history( - self.env, - self.tempdir, - self.sim_params, - {1: trades}) - - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, self.sim_params.data_frequency) - pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, data_portal, + pp = perf.PerformancePeriod(1000.0, self.env.asset_finder, self.sim_params.data_frequency) pp.position_tracker = pt @@ -2413,22 +2395,8 @@ class TestPositionTracker(unittest.TestCase): sim_params = factory.create_simulation_parameters( num_days=4, env=self.env ) - trades = factory.create_trade_history( - 1, - [10, 10, 10, 11], - [100, 100, 100, 100], - oneday, - sim_params, - env=self.env - ) - data_portal = create_data_portal_from_trade_history( - self.env, - self.tempdir, - sim_params, - {1: trades}) - - pt = perf.PositionTracker(self.env.asset_finder, data_portal, + pt = perf.PositionTracker(self.env.asset_finder, sim_params.data_frequency) pos_stats = pt.stats() @@ -2450,7 +2418,7 @@ class TestPositionTracker(unittest.TestCase): self.assertNotIsInstance(val, (bool, np.bool_)) def test_position_values_and_exposures(self): - pt = perf.PositionTracker(self.env.asset_finder, None, None) + pt = perf.PositionTracker(self.env.asset_finder, None) dt = pd.Timestamp("1984/03/06 3:00PM") pos1 = perf.Position(1, amount=np.float64(10.0), last_sale_date=dt, last_sale_price=10) @@ -2482,7 +2450,7 @@ class TestPositionTracker(unittest.TestCase): self.assertEqual(100 - 200 + 300000 - 400000, pos_stats.net_exposure) def test_update_positions(self): - pt = perf.PositionTracker(self.env.asset_finder, None, None) + pt = perf.PositionTracker(self.env.asset_finder, None) dt = pd.Timestamp("2014/01/01 3:00PM") pos1 = perf.Position(1, amount=np.float64(10.0), last_sale_date=dt, last_sale_price=10) diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 8e204710..98f75b50 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -492,7 +492,6 @@ class TradingAlgorithm(object): self.perf_tracker = PerformanceTracker( sim_params=self.sim_params, env=self.trading_environment, - data_portal=self.data_portal ) # Set the dt initially to the period start by forcing it to change. @@ -603,14 +602,17 @@ class TradingAlgorithm(object): # Create zipline and loop through simulated_trading. # Each iteration returns a perf dictionary - perfs = [] - for perf in self.get_generator(): - perfs.append(perf) + try: + perfs = [] + for perf in self.get_generator(): + perfs.append(perf) - # convert perf dict to pandas dataframe - daily_stats = self._create_daily_stats(perfs) + # convert perf dict to pandas dataframe + daily_stats = self._create_daily_stats(perfs) - self.analyze(daily_stats) + self.analyze(daily_stats) + finally: + self.data_portal = None return daily_stats @@ -1057,7 +1059,7 @@ class TradingAlgorithm(object): def updated_portfolio(self): if self.portfolio_needs_update: self.perf_tracker.position_tracker.sync_last_sale_prices( - self.datetime, self._in_before_trading_start) + self.datetime, self._in_before_trading_start, self.data_portal) self._portfolio = \ self.perf_tracker.get_portfolio(self.performance_needs_update) self.portfolio_needs_update = False @@ -1071,7 +1073,7 @@ class TradingAlgorithm(object): def updated_account(self): if self.account_needs_update: self.perf_tracker.position_tracker.sync_last_sale_prices( - self.datetime, self._in_before_trading_start) + self.datetime, self._in_before_trading_start, self.data_portal) self._account = \ self.perf_tracker.get_account(self.performance_needs_update) self.account_needs_update = False diff --git a/zipline/finance/performance/period.py b/zipline/finance/performance/period.py index 1488d666..23a1864c 100644 --- a/zipline/finance/performance/period.py +++ b/zipline/finance/performance/period.py @@ -133,7 +133,6 @@ class PerformancePeriod(object): starting_cash, asset_finder, data_frequency, - data_portal, period_open=None, period_close=None, keep_transactions=True, @@ -144,8 +143,6 @@ class PerformancePeriod(object): self.asset_finder = asset_finder self.data_frequency = data_frequency - self._data_portal = data_portal - self.period_open = period_open self.period_close = period_close diff --git a/zipline/finance/performance/position_tracker.py b/zipline/finance/performance/position_tracker.py index 88d68dc5..ea1600a1 100644 --- a/zipline/finance/performance/position_tracker.py +++ b/zipline/finance/performance/position_tracker.py @@ -119,14 +119,9 @@ def calc_gross_value(long_value, short_value): class PositionTracker(object): - def __init__(self, asset_finder, data_portal, data_frequency): + def __init__(self, asset_finder, data_frequency): self.asset_finder = asset_finder - # FIXME really want to avoid storing a data portal here, - # but the path to get to maybe_create_close_position_transaction - # is long and tortuous - self._data_portal = data_portal - # sid => position object self.positions = positiondict() # Arrays for quick calculations of positions value @@ -316,12 +311,12 @@ class PositionTracker(object): return net_cash_payment - def maybe_create_close_position_transaction(self, asset, dt): + def maybe_create_close_position_transaction(self, asset, dt, data_portal): if not self.positions.get(asset): return None amount = self.positions.get(asset).amount - price = self._data_portal.get_spot_value( + price = data_portal.get_spot_value( asset, 'price', dt, self.data_frequency) # Get the last traded price if price is no longer available @@ -372,8 +367,8 @@ class PositionTracker(object): positions.append(pos.to_dict()) return positions - def sync_last_sale_prices(self, dt, handle_non_market_minutes): - data_portal = self._data_portal + def sync_last_sale_prices(self, dt, handle_non_market_minutes, + data_portal): if not handle_non_market_minutes: for asset, position in iteritems(self.positions): last_sale_price = data_portal.get_spot_value( diff --git a/zipline/finance/performance/tracker.py b/zipline/finance/performance/tracker.py index 6307416b..ff98ebb1 100644 --- a/zipline/finance/performance/tracker.py +++ b/zipline/finance/performance/tracker.py @@ -78,7 +78,7 @@ class PerformanceTracker(object): """ Tracks the performance of the algorithm. """ - def __init__(self, sim_params, env, data_portal): + def __init__(self, sim_params, env): self.sim_params = sim_params self.env = env @@ -101,15 +101,8 @@ class PerformanceTracker(object): self.trading_days = all_trading_days[mask] - self._data_portal = data_portal - if data_portal is not None: - self._adjustment_reader = data_portal._adjustment_reader - else: - self._adjustment_reader = None - self.position_tracker = PositionTracker( asset_finder=env.asset_finder, - data_portal=data_portal, data_frequency=self.sim_params.data_frequency) if self.emission_rate == 'daily': @@ -132,7 +125,6 @@ class PerformanceTracker(object): # initial cash is your capital base. starting_cash=self.capital_base, data_frequency=self.sim_params.data_frequency, - data_portal=data_portal, # the cumulative period will be calculated over the entire test. period_open=self.period_start, period_close=self.period_end, @@ -152,7 +144,6 @@ class PerformanceTracker(object): # initial cash is your capital base. starting_cash=self.capital_base, data_frequency=self.sim_params.data_frequency, - data_portal=data_portal, # the daily period will be calculated for the market day period_open=self.market_open, period_close=self.market_close, @@ -264,13 +255,13 @@ class PerformanceTracker(object): self.cumulative_performance.handle_commission(cost) self.todays_performance.handle_commission(cost) - def process_close_position(self, asset, dt): + def process_close_position(self, asset, dt, data_portal): txn = self.position_tracker.\ - maybe_create_close_position_transaction(asset, dt) + maybe_create_close_position_transaction(asset, dt, data_portal) if txn: self.process_transaction(txn) - def check_upcoming_dividends(self, next_trading_day): + def check_upcoming_dividends(self, next_trading_day, adjustment_reader): """ Check if we currently own any stocks with dividends whose ex_date is the next trading day. Track how much we should be payed on those @@ -280,7 +271,7 @@ class PerformanceTracker(object): is the next trading day. Apply all such benefits, then recalculate performance. """ - if self._adjustment_reader is None: + if adjustment_reader is None: return position_tracker = self.position_tracker held_sids = set(position_tracker.positions) @@ -291,10 +282,10 @@ class PerformanceTracker(object): if held_sids: asset_finder = self.env.asset_finder - cash_dividends = self._adjustment_reader.\ + cash_dividends = adjustment_reader.\ get_dividends_with_ex_date(held_sids, next_trading_day, asset_finder) - stock_dividends = self._adjustment_reader.\ + stock_dividends = adjustment_reader.\ get_stock_dividends_with_ex_date(held_sids, next_trading_day, asset_finder) @@ -310,7 +301,7 @@ class PerformanceTracker(object): self.cumulative_performance.handle_dividends_paid(net_cash_payment) self.todays_performance.handle_dividends_paid(net_cash_payment) - def handle_minute_close(self, dt): + def handle_minute_close(self, dt, data_portal): """ Handles the close of the given minute. This includes handling market-close functions if the given minute is the end of the market @@ -327,7 +318,7 @@ class PerformanceTracker(object): A tuple of the minute perf packet and daily perf packet. If the market day has not ended, the daily perf packet is None. """ - self.position_tracker.sync_last_sale_prices(dt, False) + self.position_tracker.sync_last_sale_prices(dt, False, data_portal) self.update_performance() todays_date = normalize_date(dt) account = self.get_account(False) @@ -346,16 +337,18 @@ class PerformanceTracker(object): # if this is the close, update dividends for the next day. # Return the performance tuple if dt == self.market_close: - return minute_packet, self._handle_market_close(todays_date) + return minute_packet, self._handle_market_close( + todays_date, data_portal._adjustment_reader, + ) else: return minute_packet, None - def handle_market_close_daily(self, dt): + def handle_market_close_daily(self, dt, data_portal): """ Function called after handle_data when running with daily emission rate. """ - self.position_tracker.sync_last_sale_prices(dt, False) + self.position_tracker.sync_last_sale_prices(dt, False, data_portal) self.update_performance() completed_date = self.day account = self.get_account(False) @@ -368,11 +361,12 @@ class PerformanceTracker(object): benchmark_value, account.leverage) - daily_packet = self._handle_market_close(completed_date) - + daily_packet = self._handle_market_close( + completed_date, data_portal._adjustment_reader, + ) return daily_packet - def _handle_market_close(self, completed_date): + def _handle_market_close(self, completed_date, adjustment_reader): # increment the day counter before we move markers forward. self.day_count += 1.0 @@ -406,7 +400,8 @@ class PerformanceTracker(object): return daily_update # Check for any dividends, then return the daily perf packet - self.check_upcoming_dividends(next_trading_day=next_trading_day) + self.check_upcoming_dividends(next_trading_day=next_trading_day, + adjustment_reader=adjustment_reader) return daily_update def handle_simulation_end(self): diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index b9200a5b..33a9c432 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -97,20 +97,9 @@ class AlgorithmSimulator(object): Main generator work loop. """ algo = self.algo - algo.data_portal = self.data_portal - handle_data = algo.event_manager.handle_data - current_data = self.current_data - data_portal = self.data_portal - - # can't cache a pointer to algo.perf_tracker because we're not - # guaranteed that the algo doesn't swap out perf trackers during - # its lifetime. - # likewise, we can't cache a pointer to the blotter. - - algo.perf_tracker.position_tracker.data_portal = data_portal - - def every_bar(dt_to_use): + def every_bar(dt_to_use, current_data=self.current_data, + handle_data=algo.event_manager.handle_data): # called every tick (minute or day). self.simulation_dt = dt_to_use @@ -152,7 +141,8 @@ class AlgorithmSimulator(object): self.algo.account_needs_update = True self.algo.performance_needs_update = True - def once_a_day(midnight_dt): + def once_a_day(midnight_dt, current_data=self.current_data, + data_portal=self.data_portal): # Get the positions before updating the date so that prices are # fetched for trading close instead of midnight positions = algo.perf_tracker.position_tracker.positions @@ -183,11 +173,15 @@ class AlgorithmSimulator(object): # call before trading start algo.before_trading_start(current_data) - def handle_benchmark(date): + def handle_benchmark(date, benchmark_source=self.benchmark_source): algo.perf_tracker.all_benchmark_returns[date] = \ - self.benchmark_source.get_value(date) + benchmark_source.get_value(date) + + def on_exit(): + self.benchmark_source = self.current_data = self.data_portal = None with ExitStack() as stack: + stack.callback(on_exit) stack.enter_context(self.processor) stack.enter_context(ZiplineAPI(self.algo)) @@ -245,8 +239,9 @@ class AlgorithmSimulator(object): assets_to_clear = \ [asset for asset in position_assets if past_auto_close_date(asset)] perf_tracker = algo.perf_tracker + data_portal = self.data_portal for asset in assets_to_clear: - perf_tracker.process_close_position(asset, dt) + perf_tracker.process_close_position(asset, dt, data_portal) # Remove open orders for any sids that have reached their # auto_close_date. @@ -257,23 +252,25 @@ class AlgorithmSimulator(object): for asset in assets_to_cancel: blotter.cancel_all_orders_for_asset(asset) - @staticmethod - def _get_daily_message(dt, algo, perf_tracker): + def _get_daily_message(self, dt, algo, perf_tracker): """ Get a perf message for the given datetime. """ - perf_message = perf_tracker.handle_market_close_daily(dt) + perf_message = perf_tracker.handle_market_close_daily( + dt, self.data_portal, + ) perf_message['daily_perf']['recorded_vars'] = algo.recorded_vars return perf_message - @staticmethod - def _get_minute_message(dt, algo, perf_tracker): + def _get_minute_message(self, dt, algo, perf_tracker): """ Get a perf message for the given datetime. """ rvars = algo.recorded_vars - minute_message, daily_message = perf_tracker.handle_minute_close(dt) + minute_message, daily_message = perf_tracker.handle_minute_close( + dt, self.data_portal, + ) minute_message['minute_perf']['recorded_vars'] = rvars if daily_message: diff --git a/zipline/protocol.py b/zipline/protocol.py index da0d2397..9db94914 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -16,9 +16,8 @@ import pandas as pd from .utils.enum import enum -from zipline._protocol import BarData as _BarData +from zipline._protocol import BarData # noqa -BarData = _BarData # Datasource type should completely determine the other fields of a # message with its type.