mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 07:35:55 +08:00
ENH: Add fast "vectorized" minute_to_session_label for DatetimeIndex
The new TradingCalendar method is called `minute_index_to_session_labels`. It takes a DatetimeIndex of market minutes sorted in ascending order and returns a DatetimeIndex of the corresponding session labels. The new method is approximately 100x faster than mapping `minute_to_session_label` over a large DatetimeIndex.
This commit is contained in:
@@ -432,6 +432,22 @@ class ExchangeCalendarTestBase(object):
|
||||
direction="none"
|
||||
)
|
||||
|
||||
@parameterized.expand([
|
||||
(1, 0),
|
||||
(2, 0),
|
||||
(2, 1),
|
||||
])
|
||||
def test_minute_index_to_session_labels(self, interval, offset):
|
||||
minutes = self.calendar.minutes_for_sessions_in_range('2011-01-04',
|
||||
'2011-04-04')
|
||||
minutes = minutes[range(offset, len(minutes), interval)]
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
np.array(minutes.map(self.calendar.minute_to_session_label),
|
||||
dtype='datetime64[ns]'),
|
||||
self.calendar.minute_index_to_session_labels(minutes)
|
||||
)
|
||||
|
||||
def test_next_prev_session(self):
|
||||
session_labels = self.answers.index[1:-2]
|
||||
max_idx = len(session_labels) - 1
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from numpy cimport ndarray, long_t
|
||||
from numpy import searchsorted
|
||||
from numpy cimport ndarray, int64_t
|
||||
from numpy import empty, searchsorted, int64
|
||||
cimport cython
|
||||
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
def next_divider_idx(ndarray[long_t, ndim=1] dividers, long_t minute_val):
|
||||
cpdef int next_divider_idx(ndarray[int64_t, ndim=1] dividers, int64_t minute_val):
|
||||
cdef int divider_idx
|
||||
cdef long target
|
||||
cdef int64_t target
|
||||
|
||||
divider_idx = searchsorted(dividers, minute_val, side="right")
|
||||
target = dividers[divider_idx]
|
||||
@@ -20,8 +20,8 @@ def next_divider_idx(ndarray[long_t, ndim=1] dividers, long_t minute_val):
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
def previous_divider_idx(ndarray[long_t, ndim=1] dividers,
|
||||
long_t minute_val):
|
||||
def previous_divider_idx(ndarray[int64_t, ndim=1] dividers,
|
||||
int64_t minute_val):
|
||||
cdef int divider_idx
|
||||
|
||||
divider_idx = searchsorted(dividers, minute_val)
|
||||
@@ -31,9 +31,9 @@ def previous_divider_idx(ndarray[long_t, ndim=1] dividers,
|
||||
|
||||
return divider_idx - 1
|
||||
|
||||
def is_open(ndarray[long_t, ndim=1] opens,
|
||||
ndarray[long_t, ndim=1] closes,
|
||||
long_t minute_val):
|
||||
def is_open(ndarray[int64_t, ndim=1] opens,
|
||||
ndarray[int64_t, ndim=1] closes,
|
||||
int64_t minute_val):
|
||||
cdef open_idx, close_idx
|
||||
|
||||
open_idx = searchsorted(opens, minute_val)
|
||||
@@ -51,3 +51,24 @@ def is_open(ndarray[long_t, ndim=1] opens,
|
||||
# this can happen if we're outside the schedule's range (like
|
||||
# after the last close)
|
||||
return False
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
def minutes_to_session_labels(ndarray[int64_t, ndim=1] minutes,
|
||||
minute_to_session_label,
|
||||
ndarray[int64_t, ndim=1] closes):
|
||||
cdef int current_idx, next_idx, close_idx
|
||||
current_idx = next_idx = close_idx = 0
|
||||
|
||||
cdef ndarray[int64_t, ndim=1] results = empty(len(minutes), dtype=int64)
|
||||
|
||||
while current_idx < len(minutes):
|
||||
close_idx += searchsorted(closes[close_idx:],
|
||||
minutes[current_idx], side="right")
|
||||
next_idx += next_divider_idx(minutes[next_idx:], closes[close_idx])
|
||||
results[current_idx:next_idx] = minute_to_session_label(
|
||||
minutes[current_idx]
|
||||
)
|
||||
current_idx = next_idx
|
||||
|
||||
return results
|
||||
|
||||
@@ -29,7 +29,13 @@ from pandas.tseries.offsets import CustomBusinessDay
|
||||
from zipline.utils.calendars._calendar_helpers import (
|
||||
next_divider_idx,
|
||||
previous_divider_idx,
|
||||
is_open
|
||||
is_open,
|
||||
minutes_to_session_labels,
|
||||
)
|
||||
from zipline.utils.input_validation import (
|
||||
attrgetter,
|
||||
coerce,
|
||||
preprocess,
|
||||
)
|
||||
from zipline.utils.memoize import remember_last, lazyval
|
||||
|
||||
@@ -659,13 +665,14 @@ class TradingCalendar(with_metaclass(ABCMeta)):
|
||||
|
||||
return DatetimeIndex(all_minutes).tz_localize("UTC")
|
||||
|
||||
@preprocess(dt=coerce(pd.Timestamp, attrgetter('value')))
|
||||
def minute_to_session_label(self, dt, direction="next"):
|
||||
"""
|
||||
Given a minute, get the label of its containing session.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dt : pd.Timestamp
|
||||
dt : pd.Timestamp or nanosecond offset
|
||||
The dt for which to get the containing session.
|
||||
|
||||
direction: str
|
||||
@@ -684,17 +691,17 @@ class TradingCalendar(with_metaclass(ABCMeta)):
|
||||
The label of the containing session.
|
||||
"""
|
||||
|
||||
idx = searchsorted(self.market_closes_nanos, dt.value)
|
||||
idx = searchsorted(self.market_closes_nanos, dt)
|
||||
current_or_next_session = self.schedule.index[idx]
|
||||
|
||||
if direction == "previous":
|
||||
if not is_open(self.market_opens_nanos, self.market_closes_nanos,
|
||||
dt.value):
|
||||
dt):
|
||||
# if the exchange is closed, use the previous session
|
||||
return self.schedule.index[idx - 1]
|
||||
elif direction == "none":
|
||||
if not is_open(self.market_opens_nanos, self.market_closes_nanos,
|
||||
dt.value):
|
||||
dt):
|
||||
# if the exchange is closed, blow up
|
||||
raise ValueError("The given dt is not an exchange minute!")
|
||||
elif direction != "next":
|
||||
@@ -704,6 +711,30 @@ class TradingCalendar(with_metaclass(ABCMeta)):
|
||||
|
||||
return current_or_next_session
|
||||
|
||||
def minute_index_to_session_labels(self, index):
|
||||
"""
|
||||
Given a sorted DatetimeIndex of market minutes, return a
|
||||
DatetimeIndex of the corresponding session labels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
index: pd.DatetimeIndex or pd.Series
|
||||
The ordered list of market minutes we want session labels for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DatetimeIndex (UTC)
|
||||
The list of session labels corresponding to the given minutes.
|
||||
"""
|
||||
def minute_to_session_label_nanos(dt_nanos):
|
||||
return self.minute_to_session_label(dt_nanos).value
|
||||
|
||||
return DatetimeIndex(minutes_to_session_labels(
|
||||
index.values.astype(np.int64),
|
||||
minute_to_session_label_nanos,
|
||||
self.market_closes_nanos,
|
||||
).astype('datetime64[ns]'), tz='UTC')
|
||||
|
||||
def _special_dates(self, calendars, ad_hoc_dates, start_date, end_date):
|
||||
"""
|
||||
Union an iterable of pairs of the form (time, calendar)
|
||||
|
||||
Reference in New Issue
Block a user