mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 02:37:32 +08:00
BUG: Fix bad error handling in history loader.
Fixes a bug where we'd fail to raise an error if the start/end of a history window call aren't in the loader's calendar. We started dropping this error after a previous change replaced calls to `index.get_loc` with calls to `index.searchsorted` to avoid creating hash tables in pandas.
This commit is contained in:
@@ -34,6 +34,7 @@ from zipline.lib.adjustment import Float64Multiply, Float64Add
|
||||
from zipline.utils.cache import ExpiringCache
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.numpy_utils import float64_dtype
|
||||
from zipline.utils.pandas_utils import find_in_sorted_index
|
||||
|
||||
|
||||
class HistoryCompatibleUSEquityAdjustmentReader(object):
|
||||
@@ -376,14 +377,10 @@ class HistoryLoader(with_metaclass(ABCMeta)):
|
||||
size = len(dts)
|
||||
asset_windows = {}
|
||||
needed_assets = []
|
||||
cal = self._calendar
|
||||
|
||||
assets = self._asset_finder.retrieve_all(assets)
|
||||
|
||||
try:
|
||||
end_ix = self._calendar.searchsorted(end)
|
||||
except KeyError:
|
||||
raise KeyError("{0} not in calendar [{1}...{2}]".format(
|
||||
end, self._calendar[0], self._calendar[-1]))
|
||||
end_ix = find_in_sorted_index(cal, end)
|
||||
|
||||
for asset in assets:
|
||||
try:
|
||||
@@ -401,15 +398,9 @@ class HistoryLoader(with_metaclass(ABCMeta)):
|
||||
asset_windows[asset] = window
|
||||
|
||||
if needed_assets:
|
||||
start = dts[0]
|
||||
|
||||
offset = 0
|
||||
try:
|
||||
start_ix = self._calendar.searchsorted(start)
|
||||
except KeyError:
|
||||
raise KeyError("{0} not in calendar [{1}...{2}]".format(
|
||||
start, self._calendar[0], self._calendar[-1]))
|
||||
cal = self._calendar
|
||||
start_ix = find_in_sorted_index(cal, dts[0])
|
||||
|
||||
prefetch_end_ix = min(end_ix + self._prefetch_length, len(cal) - 1)
|
||||
prefetch_end = cal[prefetch_end_ix]
|
||||
prefetch_dts = cal[start_ix:prefetch_end_ix + 1]
|
||||
|
||||
@@ -91,6 +91,39 @@ def mask_between_time(dts, start, end, include_start=True, include_end=True):
|
||||
)
|
||||
|
||||
|
||||
def find_in_sorted_index(dts, dt):
    """
    Find the index of ``dt`` in ``dts``.

    This function should be used instead of `dts.get_loc(dt)` if the index is
    large enough that we don't want to initialize a hash table in ``dts``. In
    particular, this should always be used on minutely trading calendars.

    Parameters
    ----------
    dts : pd.DatetimeIndex
        Index in which to look up ``dt``. **Must be sorted**.
    dt : pd.Timestamp
        ``dt`` to be looked up.

    Returns
    -------
    ix : int
        Integer index such that dts[ix] == dt.

    Raises
    ------
    KeyError
        If dt is not in ``dts``. This includes the cases where ``dt`` is
        later than every entry in ``dts`` and where ``dts`` is empty.
    """
    size = len(dts)
    ix = dts.searchsorted(dt)
    # ``searchsorted`` returns ``len(dts)`` when ``dt`` sorts after every
    # element, so the bounds check must come before indexing; otherwise the
    # lookup would raise IndexError instead of the documented KeyError.
    if ix >= size or dts[ix] != dt:
        if size == 0:
            # There are no endpoints to report for an empty index.
            raise KeyError(
                "{0} is not in calendar (calendar is empty)".format(dt)
            )
        raise KeyError(
            "{0} is not in calendar [{1} ... {2}]".format(dt, dts[0], dts[-1])
        )
    return ix
|
||||
|
||||
|
||||
def nearest_unequal_elements(dts, dt):
|
||||
"""
|
||||
Find values in ``dts`` closest but not equal to ``dt``.
|
||||
|
||||
Reference in New Issue
Block a user