mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-05 23:28:39 +08:00
MAINT: Fix warnings from numpy on NaT comparison.
This commit is contained in:
+1
-10
@@ -18,7 +18,7 @@ from numpy import apply_along_axis, float64, isnan, nan
|
||||
from scipy.stats import rankdata
|
||||
|
||||
from zipline.utils.numpy_utils import (
|
||||
is_float,
|
||||
is_missing,
|
||||
float64_dtype,
|
||||
int64_dtype,
|
||||
datetime64ns_dtype,
|
||||
@@ -28,15 +28,6 @@ from zipline.utils.numpy_utils import (
|
||||
import_array()
|
||||
|
||||
|
||||
cpdef is_missing(ndarray data, object missing_value):
|
||||
"""
|
||||
Generic is_missing function that handles quirks with NaN.
|
||||
"""
|
||||
if is_float(data) and isnan(missing_value):
|
||||
return isnan(data)
|
||||
return (data == missing_value)
|
||||
|
||||
|
||||
def rankdata_1d_descending(ndarray data, str method):
|
||||
"""
|
||||
1D descending version of scipy.stats.rankdata.
|
||||
|
||||
+13
-1
@@ -49,7 +49,7 @@ from zipline.pipeline.loaders.testing import make_seeded_random_loader
|
||||
from zipline.utils import security_list
|
||||
from zipline.utils.calendars import get_calendar
|
||||
from zipline.utils.input_validation import expect_dimensions
|
||||
from zipline.utils.numpy_utils import as_column
|
||||
from zipline.utils.numpy_utils import as_column, isnat
|
||||
from zipline.utils.pandas_utils import timedelta_to_integral_seconds
|
||||
from zipline.utils.sentinel import sentinel
|
||||
|
||||
@@ -394,6 +394,18 @@ def check_arrays(x, y, err_msg='', verbose=True, check_dtypes=True):
|
||||
# ...then check the actual values as well.
|
||||
x = x.as_string_array()
|
||||
y = y.as_string_array()
|
||||
elif x.dtype.kind in 'mM':
|
||||
x_isnat = isnat(x)
|
||||
y_isnat = isnat(y)
|
||||
assert_array_equal(
|
||||
x_isnat,
|
||||
y_isnat,
|
||||
err_msg="NaTs not equal",
|
||||
verbose=verbose,
|
||||
)
|
||||
# Fill NaTs with zero for comparison.
|
||||
x = np.where(x_isnat, np.zeros_like(x), x)
|
||||
y = np.where(x_isnat, np.zeros_like(x), x)
|
||||
|
||||
return assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)
|
||||
|
||||
|
||||
@@ -17,8 +17,8 @@ from numpy import (
|
||||
empty,
|
||||
flatnonzero,
|
||||
hstack,
|
||||
isnan,
|
||||
nan,
|
||||
timedelta64,
|
||||
vectorize,
|
||||
where
|
||||
)
|
||||
@@ -287,6 +287,28 @@ def rolling_window(array, length):
|
||||
|
||||
# Sentinel value that isn't NaT.
|
||||
_notNaT = make_datetime64D(0)
|
||||
iNaT = NaTns.view(int64_dtype)
|
||||
assert iNaT == NaTD.view(int64_dtype), "iNaTns != iNaTD"
|
||||
|
||||
|
||||
def isnat(obj):
|
||||
"""
|
||||
Check if a value is np.NaT.
|
||||
"""
|
||||
if obj.dtype.kind not in ('m', 'M'):
|
||||
raise ValueError("%s is not a numpy datetime or timedelta")
|
||||
return obj.view(int64_dtype) == iNaT
|
||||
|
||||
|
||||
def is_missing(data, missing_value):
|
||||
"""
|
||||
Generic is_missing function that handles NaN and NaT.
|
||||
"""
|
||||
if is_float(data) and isnan(missing_value):
|
||||
return isnan(data)
|
||||
elif is_datetime(data) and isnat(missing_value):
|
||||
return isnat(data)
|
||||
return (data == missing_value)
|
||||
|
||||
|
||||
def busday_count_mask_NaT(begindates, enddates, out=None):
|
||||
@@ -304,8 +326,8 @@ def busday_count_mask_NaT(begindates, enddates, out=None):
|
||||
if out is None:
|
||||
out = empty(broadcast(begindates, enddates).shape, dtype=float)
|
||||
|
||||
beginmask = (begindates == NaTD)
|
||||
endmask = (enddates == NaTD)
|
||||
beginmask = isnat(begindates)
|
||||
endmask = isnat(enddates)
|
||||
|
||||
out = busday_count(
|
||||
# Temporarily fill in non-NaT values.
|
||||
|
||||
Reference in New Issue
Block a user