From b5438ac94ecc46c7cf478beb2acb80b65585050f Mon Sep 17 00:00:00 2001 From: Jean Bredeche Date: Thu, 19 Jan 2017 16:18:26 -0500 Subject: [PATCH] ENH: add current_session property to BarData --- tests/test_bar_data.py | 45 +++++++++++++++++++++++++++++++++++++++++- zipline/_protocol.pyx | 7 +++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/tests/test_bar_data.py b/tests/test_bar_data.py index e38b7aff..67523f17 100644 --- a/tests/test_bar_data.py +++ b/tests/test_bar_data.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from datetime import timedelta +from datetime import timedelta, time from itertools import chain from nose_parameterized import parameterized @@ -20,6 +20,7 @@ import numpy as np from numpy import nan from numpy.testing import assert_almost_equal import pandas as pd +from toolz import concat from zipline._protocol import handle_non_market_minutes @@ -40,6 +41,7 @@ from zipline.testing.fixtures import ( ZiplineTestCase, ) from zipline.utils.calendars import get_calendar +from zipline.utils.calendars.trading_calendar import days_at_time OHLC = ["open", "high", "low", "close"] OHLCP = OHLC + ["price"] @@ -204,6 +206,34 @@ class TestMinuteBarData(WithCreateBarData, cls.ASSETS = [cls.ASSET1, cls.ASSET2] + def test_current_session(self): + regular_minutes = self.trading_calendar.minutes_for_sessions_in_range( + self.equity_minute_bar_days[0], + self.equity_minute_bar_days[-1] + ) + + bts_minutes = days_at_time( + self.equity_minute_bar_days, + time(8, 45), + "US/Eastern" + ) + + # some other non-market-minute + three_oh_six_am_minutes = days_at_time( + self.equity_minute_bar_days, + time(3, 06), + "US/Eastern" + ) + + all_minutes = [regular_minutes, bts_minutes, three_oh_six_am_minutes] + for minute in list(concat(all_minutes)): + bar_data = self.create_bardata(lambda: minute) + + self.assertEqual( + self.trading_calendar.minute_to_session_label(minute), + bar_data.current_session + ) + def test_minute_before_assets_trading(self): # grab minutes that include the day before the asset start minutes = self.trading_calendar.minutes_for_session( @@ -919,6 +949,19 @@ class TestDailyBarData(WithCreateBarData, session_label )[1] + def test_current_session(self): + for session in self.trading_calendar.sessions_in_range( + self.equity_daily_bar_days[0], + self.equity_daily_bar_days[-1] + ): + bar_data = self.create_bardata( + simulation_dt_func=lambda: self.get_last_minute_of_session( + session + ) + ) + + self.assertEqual(session, bar_data.current_session) + def test_day_before_assets_trading(self): # use the day before self.bcolz_daily_bar_days[0] minute = self.get_last_minute_of_session( diff --git a/zipline/_protocol.pyx b/zipline/_protocol.pyx index 0967f562..995b2153 100644 --- a/zipline/_protocol.pyx +++ b/zipline/_protocol.pyx @@ -747,6 +747,13 @@ cdef class BarData: def __set__(self, val): self._adjust_minutes = val + property current_session: + def __get__(self): + return self._trading_calendar.minute_to_session_label( + self.simulation_dt_func(), + direction="next" + ) + ################# # OLD API SUPPORT #################