diff --git a/zipline/lib/rank.pyx b/zipline/lib/rank.pyx
index a23fa231..022b39b8 100644
--- a/zipline/lib/rank.pyx
+++ b/zipline/lib/rank.pyx
@@ -18,7 +18,7 @@ from numpy import apply_along_axis, float64, isnan, nan
 from scipy.stats import rankdata
 
 from zipline.utils.numpy_utils import (
-    is_float,
+    is_missing,
     float64_dtype,
     int64_dtype,
     datetime64ns_dtype,
@@ -28,15 +28,6 @@ from zipline.utils.numpy_utils import (
 import_array()
 
 
-cpdef is_missing(ndarray data, object missing_value):
-    """
-    Generic is_missing function that handles quirks with NaN.
-    """
-    if is_float(data) and isnan(missing_value):
-        return isnan(data)
-    return (data == missing_value)
-
-
 def rankdata_1d_descending(ndarray data, str method):
     """
     1D descending version of scipy.stats.rankdata.
diff --git a/zipline/testing/core.py b/zipline/testing/core.py
index 3517a2e3..f9832767 100644
--- a/zipline/testing/core.py
+++ b/zipline/testing/core.py
@@ -49,7 +49,7 @@ from zipline.pipeline.loaders.testing import make_seeded_random_loader
 from zipline.utils import security_list
 from zipline.utils.calendars import get_calendar
 from zipline.utils.input_validation import expect_dimensions
-from zipline.utils.numpy_utils import as_column
+from zipline.utils.numpy_utils import as_column, isnat
 from zipline.utils.pandas_utils import timedelta_to_integral_seconds
 from zipline.utils.sentinel import sentinel
 
@@ -394,6 +394,18 @@ def check_arrays(x, y, err_msg='', verbose=True, check_dtypes=True):
         # ...then check the actual values as well.
         x = x.as_string_array()
         y = y.as_string_array()
+    elif x.dtype.kind in 'mM':
+        x_isnat = isnat(x)
+        y_isnat = isnat(y)
+        assert_array_equal(
+            x_isnat,
+            y_isnat,
+            err_msg="NaTs not equal",
+            verbose=verbose,
+        )
+        # Fill NaTs with zero so the value comparison below ignores them.
+        x = np.where(x_isnat, np.zeros_like(x), x)
+        y = np.where(y_isnat, np.zeros_like(y), y)
 
     return assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)
 
diff --git a/zipline/utils/numpy_utils.py b/zipline/utils/numpy_utils.py
index e8e2be52..77371da6 100644
--- a/zipline/utils/numpy_utils.py
+++ b/zipline/utils/numpy_utils.py
@@ -17,8 +17,8 @@ from numpy import (
     empty,
     flatnonzero,
     hstack,
+    isnan,
     nan,
-    timedelta64,
     vectorize,
     where
 )
@@ -287,6 +287,28 @@ def rolling_window(array, length):
 
 # Sentinel value that isn't NaT.
 _notNaT = make_datetime64D(0)
+iNaT = NaTns.view(int64_dtype)
+assert iNaT == NaTD.view(int64_dtype), "iNaTns != iNaTD"
+
+
+def isnat(obj):
+    """
+    Check element-wise whether a datetime64/timedelta64 value is NaT.
+    """
+    if obj.dtype.kind not in ('m', 'M'):
+        raise ValueError("%s is not a numpy datetime or timedelta" % (obj,))
+    return obj.view(int64_dtype) == iNaT
+
+
+def is_missing(data, missing_value):
+    """
+    Generic is_missing function that handles NaN and NaT.
+    """
+    if is_float(data) and isnan(missing_value):
+        return isnan(data)
+    elif is_datetime(data) and isnat(missing_value):
+        return isnat(data)
+    return (data == missing_value)
 
 
 def busday_count_mask_NaT(begindates, enddates, out=None):
@@ -304,8 +326,8 @@ def busday_count_mask_NaT(begindates, enddates, out=None):
     if out is None:
         out = empty(broadcast(begindates, enddates).shape, dtype=float)
 
-    beginmask = (begindates == NaTD)
-    endmask = (enddates == NaTD)
+    beginmask = isnat(begindates)
+    endmask = isnat(enddates)
 
     out = busday_count(
         # Temporarily fill in non-NaT values.