MAINT: Fix warnings from numpy on NaT comparison.

This commit is contained in:
Scott Sanderson
2016-07-20 18:29:51 -04:00
parent 9bd6cab115
commit af5f4be17c
3 changed files with 39 additions and 14 deletions
+1 -10
View File
@@ -18,7 +18,7 @@ from numpy import apply_along_axis, float64, isnan, nan
from scipy.stats import rankdata
from zipline.utils.numpy_utils import (
is_float,
is_missing,
float64_dtype,
int64_dtype,
datetime64ns_dtype,
@@ -28,15 +28,6 @@ from zipline.utils.numpy_utils import (
import_array()
cpdef is_missing(ndarray data, object missing_value):
    """
    Return an elementwise mask of where ``data`` holds ``missing_value``.

    NaN needs special handling: when ``data`` is a float array and the
    sentinel is NaN, the mask comes from ``isnan`` rather than from ``==``.
    Every other combination falls back to a plain equality comparison.
    """
    # `and` short-circuits, so isnan() is only applied to the sentinel when
    # the array itself is floating point.
    nan_sentinel = is_float(data) and isnan(missing_value)
    if not nan_sentinel:
        return (data == missing_value)
    return isnan(data)
def rankdata_1d_descending(ndarray data, str method):
"""
1D descending version of scipy.stats.rankdata.
+13 -1
View File
@@ -49,7 +49,7 @@ from zipline.pipeline.loaders.testing import make_seeded_random_loader
from zipline.utils import security_list
from zipline.utils.calendars import get_calendar
from zipline.utils.input_validation import expect_dimensions
from zipline.utils.numpy_utils import as_column
from zipline.utils.numpy_utils import as_column, isnat
from zipline.utils.pandas_utils import timedelta_to_integral_seconds
from zipline.utils.sentinel import sentinel
@@ -394,6 +394,18 @@ def check_arrays(x, y, err_msg='', verbose=True, check_dtypes=True):
# ...then check the actual values as well.
x = x.as_string_array()
y = y.as_string_array()
elif x.dtype.kind in 'mM':
x_isnat = isnat(x)
y_isnat = isnat(y)
assert_array_equal(
x_isnat,
y_isnat,
err_msg="NaTs not equal",
verbose=verbose,
)
# Fill NaTs with zero for comparison.
x = np.where(x_isnat, np.zeros_like(x), x)
y = np.where(x_isnat, np.zeros_like(x), x)
return assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)
+25 -3
View File
@@ -17,8 +17,8 @@ from numpy import (
empty,
flatnonzero,
hstack,
isnan,
nan,
timedelta64,
vectorize,
where
)
@@ -287,6 +287,28 @@ def rolling_window(array, length):
# Sentinel value that isn't NaT.
_notNaT = make_datetime64D(0)
# int64 view of NaTns.  `isnat` compares the int64 view of its input against
# this value.  (NaTns / NaTD are presumably the ns- and D-unit NaT constants
# defined elsewhere in this module -- confirm.)
iNaT = NaTns.view(int64_dtype)
# Import-time sanity check: the int64 view of NaTD must equal the one taken
# from NaTns above, otherwise comparing against a single `iNaT` would be wrong.
assert iNaT == NaTD.view(int64_dtype), "iNaTns != iNaTD"
def isnat(obj):
    """
    Check if a value is np.NaT.

    Parameters
    ----------
    obj : object with a numpy ``dtype``
        Must have a datetime64 (kind ``'M'``) or timedelta64 (kind ``'m'``)
        dtype.

    Returns
    -------
    The result of comparing the int64 view of ``obj`` against ``iNaT``.

    Raises
    ------
    ValueError
        If ``obj.dtype.kind`` is neither ``'m'`` nor ``'M'``.
    """
    if obj.dtype.kind not in ('m', 'M'):
        # Interpolate the offending value; wrapping it in a tuple keeps the
        # %-formatting well defined whatever ``obj`` is.
        raise ValueError(
            "%s is not a numpy datetime or timedelta" % (obj,)
        )
    # Compare on the int64 view instead of using ``==`` against NaT directly.
    return obj.view(int64_dtype) == iNaT
def is_missing(data, missing_value):
    """
    Build a mask of where ``data`` equals ``missing_value``.

    NaN and NaT sentinels are detected with ``isnan`` / ``isnat`` instead of
    ``==``; any other sentinel uses a plain equality comparison.
    """
    # The `and` short-circuits matter: isnan()/isnat() are only applied to the
    # sentinel once the array's dtype family has been established.
    if is_float(data) and isnan(missing_value):
        mask = isnan(data)
    elif is_datetime(data) and isnat(missing_value):
        mask = isnat(data)
    else:
        mask = (data == missing_value)
    return mask
def busday_count_mask_NaT(begindates, enddates, out=None):
@@ -304,8 +326,8 @@ def busday_count_mask_NaT(begindates, enddates, out=None):
if out is None:
out = empty(broadcast(begindates, enddates).shape, dtype=float)
beginmask = (begindates == NaTD)
endmask = (enddates == NaTD)
beginmask = isnat(begindates)
endmask = isnat(enddates)
out = busday_count(
# Temporarily fill in non-NaT values.