mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 21:24:16 +08:00
254 lines
5.8 KiB
Python
254 lines
5.8 KiB
Python
"""
|
|
Utilities for working with numpy arrays.
|
|
"""
|
|
from numpy import (
|
|
broadcast,
|
|
busday_count,
|
|
datetime64,
|
|
dtype,
|
|
empty,
|
|
nan,
|
|
where
|
|
)
|
|
from numpy.lib.stride_tricks import as_strided
|
|
from toolz import flip
|
|
|
|
# Commonly used numpy dtype objects, built once at import time so the rest of
# the code base can refer to them by name.

# Unsigned integer and boolean.
uint8_dtype = dtype("uint8")
bool_dtype = dtype("bool")

# Signed integer.
int64_dtype = dtype("int64")

# Floating point.
float32_dtype = dtype("float32")
float64_dtype = dtype("float64")

# Complex.
complex128_dtype = dtype("complex128")

# Datetimes, at day and nanosecond resolution.
datetime64D_dtype = dtype("datetime64[D]")
datetime64ns_dtype = dtype("datetime64[ns]")
|
|
|
|
# Single-argument constructors for datetime64 scalars with a fixed unit.
# ``flip`` swaps the argument order of ``datetime64``, so the unit is bound
# here and the value is supplied by the caller.
make_datetime64ns = flip(datetime64, 'ns')
make_datetime64D = flip(datetime64, 'D')

# Maps each supported datetime64 dtype to the NaT scalar of the same unit.
NaTmap = dict(
    (dtype('datetime64[{0}]'.format(unit)), datetime64('NaT', unit))
    for unit in ('ns', 'us', 'ms', 's', 'm', 'D')
)

# Function form of the table lookup; raises KeyError for unregistered dtypes.
NaT_for_dtype = NaTmap.__getitem__

# NaT values for the two resolutions this module names explicitly.
NaTns = NaTmap[datetime64ns_dtype]
NaTD = NaTmap[datetime64D_dtype]
|
|
|
|
|
|
# Value used to represent "missing" for each dtype that has one registered.
# Dtypes that are not listed here have no default fill value.
_FILLVALUE_DEFAULTS = {
    bool_dtype: False,
    float32_dtype: nan,
    float64_dtype: nan,
    datetime64ns_dtype: NaTns,
}


class NoDefaultMissingValue(Exception):
    """Raised when a dtype has no entry in the default fill-value table."""


def default_missing_value_for_dtype(dtype):
    """
    Look up the default fill value registered for `dtype`.

    Parameters
    ----------
    dtype : np.dtype
        The dtype to look up. The registered keys are numpy dtype objects.

    Returns
    -------
    fill_value : object
        The value registered for `dtype` in ``_FILLVALUE_DEFAULTS``.

    Raises
    ------
    NoDefaultMissingValue
        If `dtype` is not present in ``_FILLVALUE_DEFAULTS``.
    """
    try:
        fill_value = _FILLVALUE_DEFAULTS[dtype]
    except KeyError:
        message = "No default value registered for dtype %s." % dtype
        raise NoDefaultMissingValue(message)
    return fill_value
|
|
|
|
|
|
def repeat_first_axis(array, count):
    """
    Build a view of `array` repeated `count` times along a new first axis.

    Parameters
    ----------
    array : np.array
        The array to restride.
    count : int
        Number of times to repeat `array`.

    Returns
    -------
    result : array
        Array of shape (count,) + array.shape in which every entry along the
        first axis is `array`.

    Example
    -------
    >>> from numpy import arange
    >>> a = arange(3); a
    array([0, 1, 2])
    >>> repeat_first_axis(a, 2)
    array([[0, 1, 2],
           [0, 1, 2]])
    >>> repeat_first_axis(a, 4)
    array([[0, 1, 2],
           [0, 1, 2],
           [0, 1, 2],
           [0, 1, 2]])

    Notes
    -----
    No data is copied: the result shares memory with `array`. Make a copy
    first if you intend to assign to either the input or the output.

    See Also
    --------
    repeat_last_axis
    """
    # A stride of zero on the new leading axis makes every repetition read
    # the same underlying memory.
    repeated_shape = (count,) + array.shape
    repeated_strides = (0,) + array.strides
    return as_strided(array, shape=repeated_shape, strides=repeated_strides)
|
|
|
|
|
|
def repeat_last_axis(array, count):
    """
    Build a view of `array` repeated `count` times along a new last axis.

    Parameters
    ----------
    array : np.array
        The array to restride.
    count : int
        Number of times to repeat `array`.

    Returns
    -------
    result : array
        Array of shape array.shape + (count,) in which each element of
        `array` is repeated `count` times along the last axis.

    Example
    -------
    >>> from numpy import arange
    >>> a = arange(3); a
    array([0, 1, 2])
    >>> repeat_last_axis(a, 2)
    array([[0, 0],
           [1, 1],
           [2, 2]])
    >>> repeat_last_axis(a, 4)
    array([[0, 0, 0, 0],
           [1, 1, 1, 1],
           [2, 2, 2, 2]])

    Notes
    -----
    No data is copied: the result shares memory with `array`. Make a copy
    first if you intend to assign to either the input or the output.

    See Also
    --------
    repeat_first_axis
    """
    # A stride of zero on the new trailing axis makes every repetition read
    # the same underlying element.
    repeated_shape = array.shape + (count,)
    repeated_strides = array.strides + (0,)
    return as_strided(array, shape=repeated_shape, strides=repeated_strides)
|
|
|
|
|
|
def rolling_window(array, length):
    """
    Restride an array of shape

        (X_0, X_1, ... X_N)

    into an array of shape

        (X_0 - length + 1, length, X_1, ... X_N)

    where each slice at index i along the first axis is equivalent to

        result[i] = array[i:i + length]

    Parameters
    ----------
    array : np.ndarray
        The base array.
    length : int
        Number of consecutive entries along the first axis of `array` to
        include in each window.

    Returns
    -------
    out : np.ndarray
        View of `array` holding one window per valid starting position.

    Raises
    ------
    IndexError
        If `array` is zero-dimensional, or if `length` is greater than the
        size of the first axis of `array`.

    Example
    -------
    >>> from numpy import arange
    >>> a = arange(25).reshape(5, 5)
    >>> a
    array([[ 0,  1,  2,  3,  4],
           [ 5,  6,  7,  8,  9],
           [10, 11, 12, 13, 14],
           [15, 16, 17, 18, 19],
           [20, 21, 22, 23, 24]])

    >>> rolling_window(a, 2)
    array([[[ 0,  1,  2,  3,  4],
            [ 5,  6,  7,  8,  9]],
    <BLANKLINE>
           [[ 5,  6,  7,  8,  9],
            [10, 11, 12, 13, 14]],
    <BLANKLINE>
           [[10, 11, 12, 13, 14],
            [15, 16, 17, 18, 19]],
    <BLANKLINE>
           [[15, 16, 17, 18, 19],
            [20, 21, 22, 23, 24]]])

    Notes
    -----
    The result is produced with ``as_strided`` and shares memory with
    `array`; overlapping windows alias the same elements. Make a copy first
    if you intend to assign to either the input or the output.
    """
    orig_shape = array.shape
    if not orig_shape:
        raise IndexError("Can't restride a scalar.")
    # A window exactly as long as the first axis is valid (it yields a single
    # window), so only reject windows that are strictly longer.
    elif orig_shape[0] < length:
        raise IndexError(
            "Can't restride array of shape {shape} with"
            " a window length of {len}".format(
                shape=orig_shape,
                len=length,
            )
        )

    num_windows = (orig_shape[0] - length + 1)
    new_shape = (num_windows, length) + orig_shape[1:]

    # Stepping to the next window and stepping to the next entry inside a
    # window both advance by one entry of the original first axis.
    new_strides = (array.strides[0],) + array.strides

    return as_strided(array, new_shape, new_strides)
|
|
|
|
|
|
# Sentinel value that isn't NaT (day 0 of the datetime64 epoch). It is
# substituted for NaT inputs so that numpy.busday_count can run.
_notNaT = datetime64(0, 'D')


def busday_count_mask_NaT(begindates, enddates, out=None):
    """
    Simple wrapper around numpy.busday_count that returns `float` arrays
    rather than int arrays, and handles `NaT`s by returning `NaN`s where the
    inputs were `NaT`.

    Doesn't support custom weekdays or calendars, but probably should in the
    future.

    Parameters
    ----------
    begindates : np.ndarray of datetime64
        First dates of the ranges to count (inclusive in numpy.busday_count).
    enddates : np.ndarray of datetime64
        End dates of the ranges to count (exclusive in numpy.busday_count).
    out : np.ndarray, optional
        Array to store the result in. If omitted, a float array with the
        broadcast shape of `begindates` and `enddates` is allocated.

    Returns
    -------
    out : np.ndarray
        Business day counts, with NaN wherever either input was NaT.

    See Also
    --------
    np.busday_count
    """
    # Imported here so this function does not depend on edits to the module
    # header. numpy.isnat is available from NumPy 1.13.
    from numpy import isnat

    if out is None:
        out = empty(broadcast(begindates, enddates).shape, dtype=float)

    # NaT must be detected with isnat rather than ``==``: on current NumPy
    # releases NaT compares unequal to everything, including itself, so an
    # equality test would never flag it and busday_count would then fail on
    # the NaT inputs.
    beginmask = isnat(begindates)
    endmask = isnat(enddates)

    out = busday_count(
        # Temporarily fill in non-NaT values.
        where(beginmask, _notNaT, begindates),
        where(endmask, _notNaT, enddates),
        out=out,
    )

    # Fill in entries where either input was NaT with nan in the output.
    out[beginmask | endmask] = nan
    return out
|