mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 10:43:09 +08:00
ENH: Add rolling_window to numpy_utils.py.
This commit is contained in:
@@ -148,6 +148,75 @@ def repeat_last_axis(array, count):
|
||||
return as_strided(array, array.shape + (count,), array.strides + (0,))
|
||||
|
||||
|
||||
def rolling_window(array, length):
|
||||
"""
|
||||
Restride an array of shape
|
||||
|
||||
(X_0, ... X_N)
|
||||
|
||||
into an array of shape
|
||||
|
||||
(length, X_0 - length + 1, ... X_N)
|
||||
|
||||
where each slice at index i along the first axis is equivalent to
|
||||
|
||||
result[i] = array[length * i:length * (i + 1)]
|
||||
|
||||
Parameters
|
||||
----------
|
||||
array : np.ndarray
|
||||
The base array.
|
||||
length : int
|
||||
Length of the synthetic first axis to generate.
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : np.ndarray
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> from numpy import arange
|
||||
>>> a = arange(25).reshape(5, 5)
|
||||
>>> a
|
||||
array([[ 0, 1, 2, 3, 4],
|
||||
[ 5, 6, 7, 8, 9],
|
||||
[10, 11, 12, 13, 14],
|
||||
[15, 16, 17, 18, 19],
|
||||
[20, 21, 22, 23, 24]])
|
||||
|
||||
>>> rolling_window(a, 2)
|
||||
array([[[ 0, 1, 2, 3, 4],
|
||||
[ 5, 6, 7, 8, 9]],
|
||||
|
||||
[[ 5, 6, 7, 8, 9],
|
||||
[10, 11, 12, 13, 14]],
|
||||
|
||||
[[10, 11, 12, 13, 14],
|
||||
[15, 16, 17, 18, 19]],
|
||||
|
||||
[[15, 16, 17, 18, 19],
|
||||
[20, 21, 22, 23, 24]]])
|
||||
"""
|
||||
orig_shape = array.shape
|
||||
if not orig_shape:
|
||||
raise IndexError("Can't restride a scalar.")
|
||||
elif orig_shape[0] <= length:
|
||||
raise IndexError(
|
||||
"Can't restride array of shape {shape} with"
|
||||
" a window length of {len}".format(
|
||||
shape=orig_shape,
|
||||
len=length,
|
||||
)
|
||||
)
|
||||
|
||||
num_windows = (orig_shape[0] - length + 1)
|
||||
new_shape = (num_windows, length) + orig_shape[1:]
|
||||
|
||||
new_strides = (array.strides[0],) + array.strides
|
||||
|
||||
return as_strided(array, new_shape, new_strides)
|
||||
|
||||
|
||||
# Sentinel value that isn't NaT.
|
||||
_notNaT = make_datetime64D(0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user