diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index bb0b4e69..55c76026 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -27,7 +27,7 @@ from collections import OrderedDict from zipline.data.loader import load_market_data -log = logbook.Logger('Transaction Simulator') +log = logbook.Logger('Trading') # The financial simulations in zipline depend on information @@ -76,7 +76,8 @@ class TradingEnvironment(object): load=None, bm_symbol='^GSPC', exchange_tz="US/Eastern", - max_date=None + max_date=None, + extra_dates=None ): self.prev_environment = self self.trading_day_map = OrderedDict() @@ -96,11 +97,21 @@ class TradingEnvironment(object): self.full_trading_day = datetime.timedelta(hours=6, minutes=30) self.exchange_tz = exchange_tz + bm = None for bm in self.benchmark_returns: if max_date and bm.date > max_date: break self.trading_day_map[bm.date] = bm + if bm and extra_dates: + last_day = next(reversed(self.trading_day_map)) + for extra_date in extra_dates: + extra_date = extra_date.replace(hour=0, minute=0, second=0, + microsecond=0) + if extra_date not in self.trading_day_map: + self.trading_day_map[extra_date] = \ + self.trading_day_map[last_day] + self.first_trading_day = next(self.trading_day_map.iterkeys()) self.last_trading_day = next(reversed(self.trading_day_map)) diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index 0f67bb0d..ee1bf10b 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -162,7 +162,10 @@ class AlgorithmSimulator(object): tp = self.algo.perf_tracker.todays_performance tp.rollover() if mkt_close < self.algo.perf_tracker.last_close: - mkt_close = self.get_next_close(mkt_close) + _, mkt_close = \ + trading.environment.next_open_and_close( + mkt_close + ) self.algo.perf_tracker.handle_intraday_close() risk_message = self.algo.perf_tracker.handle_simulation_end() @@ -182,12 +185,6 @@ class AlgorithmSimulator(object): perf_message['minute_perf']['recorded_vars'] = rvars return perf_message - def get_next_close(self, mkt_close): - if mkt_close >= trading.environment.last_trading_day: - return self.sim_params.last_close - else: - return trading.environment.next_open_and_close(mkt_close)[1] - def update_universe(self, event): """ Update the universe with new event information.