diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index bc2ca74a..cd93c1bc 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from collections import namedtuple import datetime from datetime import timedelta @@ -34,7 +35,6 @@ import pytz from pandas.io.common import PerformanceWarning from zipline import run_algorithm -from tests.warnings_catcher import WarningsCatcher from zipline import TradingAlgorithm from zipline.api import FixedSlippage from zipline.assets import Equity, Future @@ -1959,7 +1959,9 @@ def handle_data(context, data): pass """) - with WarningsCatcher([PerformanceWarning]) as w: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("ignore", PerformanceWarning) + algo = TradingAlgorithm( script=algocode, sim_params=sim_params, @@ -1967,18 +1969,19 @@ def handle_data(context, data): ) algo.run(self.data_portal) - self.assertEqual(len(w), 2) - for i, warning in enumerate(w): - self.assertIsInstance(warning.message, UserWarning) - self.assertEqual( - warning.message.args[0], - 'Got a time rule for the second positional argument ' - 'date_rule. You should use keyword argument ' - 'time_rule= when calling schedule_function without ' - 'specifying a date_rule' - ) - # The warnings come from line 13 and 14 in the algocode - self.assertEqual(warning.lineno, 13 + i) + self.assertEqual(len(w), 2) + + for i, warning in enumerate(w): + self.assertIsInstance(warning.message, UserWarning) + self.assertEqual( + warning.message.args[0], + 'Got a time rule for the second positional argument ' + 'date_rule. You should use keyword argument ' + 'time_rule= when calling schedule_function without ' + 'specifying a date_rule' + ) + # The warnings come from line 13 and 14 in the algocode + self.assertEqual(warning.lineno, 13 + i) self.assertEqual( algo.done_at_open, diff --git a/tests/test_api_shim.py b/tests/test_api_shim.py index 9d2652c4..634aa2f9 100644 --- a/tests/test_api_shim.py +++ b/tests/test_api_shim.py @@ -5,7 +5,6 @@ import numpy as np import pandas as pd from pandas.io.common import PerformanceWarning -from tests.warnings_catcher import WarningsCatcher from zipline import TradingAlgorithm from zipline.finance.trading import SimulationParameters from zipline.protocol import BarData @@ -292,7 +291,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase): cease to be supported, we also want to assert that we're seeing a deprecation warning. """ - with WarningsCatcher([PerformanceWarning]) as w: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("ignore", PerformanceWarning) warnings.simplefilter("default", ZiplineDeprecationWarning) algo = self.create_algo(sid_accessor_algo) algo.run(self.data_portal) @@ -320,7 +320,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase): We also want to assert that we warn that iterating over the assets in `data` is deprecated. """ - with WarningsCatcher([PerformanceWarning]) as w: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("ignore", PerformanceWarning) warnings.simplefilter("default", ZiplineDeprecationWarning) algo = self.create_algo(data_items_algo) algo.run(self.data_portal) @@ -344,7 +345,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase): ) def test_iterate_data(self): - with WarningsCatcher([PerformanceWarning]) as w: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("ignore", PerformanceWarning) warnings.simplefilter("default", ZiplineDeprecationWarning) algo = self.create_algo(simple_algo) @@ -374,7 +376,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase): ) def test_history(self): - with WarningsCatcher([PerformanceWarning]) as w: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("ignore", PerformanceWarning) warnings.simplefilter("default", ZiplineDeprecationWarning) sim_params = self.sim_params.create_new( @@ -415,7 +418,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase): expected_vol_with_split) def test_simple_transforms(self): - with WarningsCatcher([PerformanceWarning]) as w: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("ignore", PerformanceWarning) warnings.simplefilter("default", ZiplineDeprecationWarning) sim_params = SimulationParameters( @@ -485,7 +489,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase): self.assertAlmostEqual(346, algo.returns) def test_manipulation(self): - with WarningsCatcher([PerformanceWarning]) as w: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("ignore", PerformanceWarning) warnings.simplefilter("default", ZiplineDeprecationWarning) algo = self.create_algo(simple_algo) diff --git a/tests/test_clock.py b/tests/test_clock.py index 9f9b5a57..fa81b63c 100644 --- a/tests/test_clock.py +++ b/tests/test_clock.py @@ -1,7 +1,6 @@ from datetime import time from unittest import TestCase import pandas as pd -import numpy as np from zipline.gens.sim_engine import ( MinuteSimulationClock, SESSION_START, @@ -26,8 +25,8 @@ class TestClock(TestCase): ) trading_o_and_c = cls.nyse_calendar.schedule.ix[cls.sessions] - cls.opens = trading_o_and_c['market_open'].values.astype(np.int64) - cls.closes = trading_o_and_c['market_close'].values.astype(np.int64) + cls.opens = trading_o_and_c['market_open'] + cls.closes = trading_o_and_c['market_close'] def test_bts_before_session(self): clock = MinuteSimulationClock( diff --git a/tests/test_daily_history_aggregator.py b/tests/test_daily_history_aggregator.py index 2a234d31..628e4d59 100644 --- a/tests/test_daily_history_aggregator.py +++ b/tests/test_daily_history_aggregator.py @@ -120,6 +120,7 @@ class MinuteToDailyAggregationTestCase(WithBcolzEquityMinuteBarReader, self.equity_daily_aggregator = DailyHistoryAggregator( self.trading_calendar.schedule.market_open, self.bcolz_equity_minute_bar_reader, + self.trading_calendar ) @parameterized.expand([ diff --git a/tests/warnings_catcher.py b/tests/warnings_catcher.py deleted file mode 100644 index 0437afae..00000000 --- a/tests/warnings_catcher.py +++ /dev/null @@ -1,32 +0,0 @@ -from warnings import catch_warnings, WarningMessage - - -class WarningsCatcher(catch_warnings): - """ - Subclass of warnings.catch_warnings that takes a list of warning types to - ignore. - """ - def __init__(self, types_to_ignore=None): - super(WarningsCatcher, self).__init__(record=True) - - self._types_to_ignore = set(types_to_ignore or []) - - def __enter__(self): - if self._entered: - raise RuntimeError("Cannot enter %r twice" % self) - self._entered = True - self._filters = self._module.filters - self._module.filters = self._filters[:] - self._showwarning = self._module.showwarning - if self._record: - log = [] - - def showwarning(*args, **kwargs): - if args[1] in self._types_to_ignore: - return - log.append(WarningMessage(*args, **kwargs)) - - self._module.showwarning = showwarning - return log - else: - return None diff --git a/zipline/_protocol.pyx b/zipline/_protocol.pyx index c84264fc..6141d77e 100644 --- a/zipline/_protocol.pyx +++ b/zipline/_protocol.pyx @@ -463,11 +463,11 @@ cdef class BarData: cdef bool _can_trade_for_asset(self, asset, dt, adjusted_dt, data_portal): session_label = normalize_date(dt) # FIXME - if not asset._is_alive_for_session(session_label): + if not asset.is_alive_for_session(session_label): # asset isn't alive return False - if not asset._asset_exchange_open(dt): + if not asset.is_exchange_open(dt): # exchange isn't open return False @@ -525,7 +525,7 @@ cdef class BarData: cdef bool _is_stale_for_asset(self, asset, dt, adjusted_dt, data_portal): session_label = normalize_date(dt) # FIXME - if not asset._is_alive_for_session(session_label): + if not asset.is_alive_for_session(session_label): return False current_volume = data_portal.get_spot_value( diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 85928c7d..8827eff8 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -496,13 +496,11 @@ class TradingAlgorithm(object): """ trading_o_and_c = self.trading_calendar.schedule.ix[ self.sim_params.sessions] - market_closes = trading_o_and_c['market_close'].values.astype(np.int64) + market_closes = trading_o_and_c['market_close'] minutely_emission = False if self.sim_params.data_frequency == 'minute': - market_opens = trading_o_and_c['market_open'].values.astype( - np.int64 - ) + market_opens = trading_o_and_c['market_open'] minutely_emission = self.sim_params.emission_rate == "minute" else: diff --git a/zipline/assets/_assets.pyx b/zipline/assets/_assets.pyx index 311418f2..1fe11e1f 100644 --- a/zipline/assets/_assets.pyx +++ b/zipline/assets/_assets.pyx @@ -101,8 +101,6 @@ cdef class Asset: self.first_traded = first_traded self.auto_close_date = auto_close_date - - def __int__(self): return self.sid @@ -201,7 +199,7 @@ cdef class Asset: """ return cls(**dict_) - def _is_alive_for_session(self, session_label): + def is_alive_for_session(self, session_label): """ Returns whether the asset is alive at the given dt. @@ -222,7 +220,7 @@ cdef class Asset: return ref_start <= session_label.value <= ref_end - def _asset_exchange_open(self, dt_minute): + def is_exchange_open(self, dt_minute): """ Parameters ---------- @@ -233,10 +231,10 @@ cdef class Asset: ------- boolean: whether the asset's exchange is open at the given minute. """ - calendar = self._exchange_trading_calendar_for_asset() + calendar = self.exchange_trading_calendar() return calendar.is_open_on_minute(dt_minute) - def _exchange_trading_calendar_for_asset(self): + def exchange_trading_calendar(self): """ Get the calendar for this asset's exchange. diff --git a/zipline/data/daily_history_aggregator.py b/zipline/data/daily_history_aggregator.py index 7536e2d2..1e710d7b 100644 --- a/zipline/data/daily_history_aggregator.py +++ b/zipline/data/daily_history_aggregator.py @@ -99,7 +99,7 @@ class DailyHistoryAggregator(object): session_label = self._trading_calendar.minute_to_session_label(dt) for asset in assets: - if not asset._is_alive_for_session(session_label): + if not asset.is_alive_for_session(session_label): opens.append(np.NaN) continue @@ -168,7 +168,7 @@ class DailyHistoryAggregator(object): session_label = self._trading_calendar.minute_to_session_label(dt) for asset in assets: - if not asset._is_alive_for_session(session_label): + if not asset.is_alive_for_session(session_label): highs.append(np.NaN) continue @@ -237,7 +237,7 @@ class DailyHistoryAggregator(object): session_label = self._trading_calendar.minute_to_session_label(dt) for asset in assets: - if not asset._is_alive_for_session(session_label): + if not asset.is_alive_for_session(session_label): lows.append(np.NaN) continue @@ -307,7 +307,7 @@ class DailyHistoryAggregator(object): session_label = self._trading_calendar.minute_to_session_label(dt) for asset in assets: - if not asset._is_alive_for_session(session_label): + if not asset.is_alive_for_session(session_label): closes.append(np.NaN) continue @@ -367,7 +367,7 @@ class DailyHistoryAggregator(object): session_label = self._trading_calendar.minute_to_session_label(dt) for asset in assets: - if not asset._is_alive_for_session(session_label): + if not asset.is_alive_for_session(session_label): volumes.append(0) continue diff --git a/zipline/gens/sim_engine.pyx b/zipline/gens/sim_engine.pyx index f044fd37..aa3a9d51 100644 --- a/zipline/gens/sim_engine.pyx +++ b/zipline/gens/sim_engine.pyx @@ -30,11 +30,10 @@ cpdef enum: BEFORE_TRADING_START_BAR = 4 cdef class MinuteSimulationClock: - cdef object sessions cdef bool minute_emission - cdef np.int64_t[:] market_opens, market_closes - cdef object before_trading_start_minutes - cdef dict minutes_by_session, minutes_to_session + cdef np.int64_t[:] market_opens_nanos, market_closes_nanos, bts_nanos, \ + sessions_nanos + cdef dict minutes_by_session def __init__(self, sessions, @@ -43,71 +42,76 @@ cdef class MinuteSimulationClock: before_trading_start_minutes, minute_emission=False): self.minute_emission = minute_emission - self.market_opens = market_opens - self.market_closes = market_closes - self.sessions = sessions + + self.market_opens_nanos = market_opens.values.astype(np.int64) + self.market_closes_nanos = market_closes.values.astype(np.int64) + self.sessions_nanos = sessions.values.astype(np.int64) + self.bts_nanos = before_trading_start_minutes.values.astype(np.int64) + self.minutes_by_session = self.calc_minutes_by_session() - self.before_trading_start_minutes = before_trading_start_minutes - - @cython.boundscheck(False) - @cython.wraparound(False) - cdef np.ndarray[np.int64_t, ndim=1] market_minutes(self, np.intp_t i): - cdef np.int64_t[:] market_opens, market_closes - - market_opens = self.market_opens - market_closes = self.market_closes - - return np.arange(market_opens[i], - market_closes[i] + _nanos_in_minute, - _nanos_in_minute) - @cython.boundscheck(False) @cython.wraparound(False) cdef dict calc_minutes_by_session(self): cdef dict minutes_by_session cdef int session_idx - cdef object session + cdef np.int64_t session_nano + cdef np.ndarray[np.int64_t, ndim=1] minutes_nanos minutes_by_session = {} - for session_idx, session in enumerate(self.sessions): - minutes_by_session[session] = pd.to_datetime( - self.market_minutes(session_idx), utc=True, box=True) + for session_idx, session_nano in enumerate(self.sessions_nanos): + minutes_nanos = np.arange( + self.market_opens_nanos[session_idx], + self.market_closes_nanos[session_idx] + _nanos_in_minute, + _nanos_in_minute + ) + minutes_by_session[session_nano] = pd.to_datetime( + minutes_nanos, utc=True, box=True + ) return minutes_by_session def __iter__(self): minute_emission = self.minute_emission - for idx, session in enumerate(self.sessions): - yield session, SESSION_START + for idx, session_nano in enumerate(self.sessions_nanos): + yield pd.Timestamp(session_nano, tz='UTC'), SESSION_START - bts_minute = self.before_trading_start_minutes[idx] - regular_minutes = self.minutes_by_session[session] + bts_minute = pd.Timestamp(self.bts_nanos[idx], tz='UTC') + regular_minutes = self.minutes_by_session[session_nano] - # we have to search anew every session, because there is no - # guarantee that any two session start on the same minute - bts_idx = regular_minutes.searchsorted(bts_minute) - - if bts_idx == len(regular_minutes): - # before_trading_start is after the last close, so don't emit - # it - for minute in regular_minutes: - yield minute, BAR - if minute_emission: - yield minute, MINUTE_END + if bts_minute > regular_minutes[-1]: + # before_trading_start is after the last close, + # so don't emit it + for minute, evt in self._get_minutes_for_list( + regular_minutes, + minute_emission + ): + yield minute, evt else: + # we have to search anew every session, because there is no + # guarantee that any two session start on the same minute + bts_idx = regular_minutes.searchsorted(bts_minute) + # emit all the minutes before bts_minute - for minute in regular_minutes[0:bts_idx]: - yield minute, BAR - if minute_emission: - yield minute, MINUTE_END + for minute, evt in self._get_minutes_for_list( + regular_minutes[0:bts_idx], + minute_emission + ): + yield minute, evt yield bts_minute, BEFORE_TRADING_START_BAR # emit all the minutes after bts_minute - for minute in regular_minutes[bts_idx:]: - yield minute, BAR - if minute_emission: - yield minute, MINUTE_END + for minute, evt in self._get_minutes_for_list( + regular_minutes[bts_idx:], + minute_emission + ): + yield minute, evt yield regular_minutes[-1], SESSION_END + + def _get_minutes_for_list(self, minutes, minute_emission): + for minute in minutes: + yield minute, BAR + if minute_emission: + yield minute, MINUTE_END