Allow view_as_window to take a tuple step, and update tests

This commit is contained in:
Steven Silvester
2015-06-07 06:13:36 -05:00
parent da8e2c7c69
commit 3f8e94ff0c
2 changed files with 55 additions and 25 deletions
+27 -16
View File
@@ -1,9 +1,10 @@
__all__ = ['view_as_blocks', 'view_as_windows']
import numbers
import numpy as np
from numpy.lib.stride_tricks import as_strided
from warnings import warn
__all__ = ['view_as_blocks', 'view_as_windows']
def view_as_blocks(arr_in, block_shape):
"""Block view of the input n-dimensional array (using re-striding).
@@ -112,13 +113,14 @@ def view_as_windows(arr_in, window_shape, step=1):
----------
arr_in : ndarray
N-d input array.
window_shape : tuple
window_shape : integer or tuple of length arr_in.ndmi
Defines the shape of the elementary n-dimensional orthotope
(better know as hyperrectangle [1]_) of the rolling window view.
step : int, optional
Number of elements to skip when moving the window forward (by
default, move forward by one). The value must be equal or larger
than one.
If an integer is given, the shape will be a hyperrectangle of
sidelength given by its value.
step : integer or tuple of length arr_in.ndim
Indicates step size at which extraction shall be performed.
If integer is given, then the step is uniform in all dimensions.
Returns
-------
@@ -215,13 +217,18 @@ def view_as_windows(arr_in, window_shape, step=1):
# -- basic checks on arguments
if not isinstance(arr_in, np.ndarray):
raise TypeError("`arr_in` must be a numpy ndarray")
if not isinstance(window_shape, tuple):
raise TypeError("`window_shape` must be a tuple")
if isinstance(window_shape, numbers.Number):
window_shape = tuple([window_shape] * arr_in.ndim)
if not (len(window_shape) == arr_in.ndim):
raise ValueError("`window_shape` is incompatible with `arr_in.shape`")
if step < 1:
raise ValueError("`step` must be >= 1")
if isinstance(step, numbers.Number):
if step < 1:
raise ValueError("`step` must be >= 1")
step = tuple([step] * arr_in.ndim)
if not (len(step) == arr_in.ndim):
raise ValueError("`step` is incompatible with `arr_in.shape`")
arr_shape = np.array(arr_in.shape)
window_shape = np.array(window_shape, dtype=arr_shape.dtype)
@@ -239,12 +246,16 @@ def view_as_windows(arr_in, window_shape, step=1):
arr_in = np.ascontiguousarray(arr_in)
new_shape = tuple((arr_shape - window_shape) // step + 1) + \
tuple(window_shape)
slices = [slice(None, None, st) for st in step]
window_strides = np.array(arr_in.strides)
arr_strides = np.array(arr_in.strides)
new_strides = np.concatenate((arr_strides * step, arr_strides))
indexing_strides = arr_in[slices].strides
arr_out = as_strided(arr_in, shape=new_shape, strides=new_strides)
win_indices_shape = (((np.array(arr_in.shape) - np.array(window_shape))
// np.array(step)) + 1)
new_shape = tuple(list(win_indices_shape) + list(window_shape))
strides = tuple(list(indexing_strides) + list(window_strides))
arr_out = as_strided(arr_in, shape=new_shape, strides=strides)
return arr_out
+28 -9
View File
@@ -47,7 +47,7 @@ def test_view_as_blocks_2D_array():
A = np.arange(4 * 4).reshape(4, 4)
B = view_as_blocks(A, (2, 2))
assert_equal(B[0, 1], np.array([[2, 3],
[6, 7]]))
[6, 7]]))
assert_equal(B[1, 0, 1, 1], 13)
@@ -67,12 +67,6 @@ def test_view_as_windows_input_not_array():
view_as_windows(A, (2,))
@raises(TypeError)
def test_view_as_windows_window_not_tuple():
A = np.arange(10)
view_as_windows(A, [2])
@raises(ValueError)
def test_view_as_windows_wrong_window_dimension():
A = np.arange(10)
@@ -96,6 +90,7 @@ def test_view_as_windows_step_below_one():
A = np.arange(10)
view_as_windows(A, (11,), step=0.9)
def test_view_as_windows_1D():
A = np.arange(10)
window_shape = (3,)
@@ -135,7 +130,7 @@ def test_view_as_windows_2D():
def test_view_as_windows_with_skip():
A = np.arange(20).reshape((5, 4))
B = view_as_windows(A, (2, 2), step=2)
B = view_as_windows(A, 2, step=2)
assert_equal(B, [[[[0, 1],
[4, 5]],
[[2, 3],
@@ -145,7 +140,7 @@ def test_view_as_windows_with_skip():
[[10, 11],
[14, 15]]]])
C = view_as_windows(A, (2, 2), step=4)
C = view_as_windows(A, 2, step=4)
assert_equal(C.shape, (1, 1, 2, 2))
@@ -157,5 +152,29 @@ def test_views_non_contiguous():
assert_warns(RuntimeWarning, view_as_windows, A, (2, 2))
def test_view_as_windows_step_tuple():
A = np.arange(24).reshape((6, 4))
B = view_as_windows(A, (3, 2), step=3)
assert B.shape == (2, 1, 3, 2)
assert B.size != A.size
C = view_as_windows(A, (3, 2), step=(3, 2))
assert C.shape == (2, 2, 3, 2)
assert C.size == A.size
assert_equal(C, [[[[0, 1],
[4, 5],
[8, 9]],
[[2, 3],
[6, 7],
[10, 11]]],
[[[12, 13],
[16, 17],
[20, 21]],
[[14, 15],
[18, 19],
[22, 23]]]])
if __name__ == '__main__':
np.testing.run_module_suite()