diff --git a/zipline/utils/numpy_utils.py b/zipline/utils/numpy_utils.py index 56315cf7..db5f168a 100644 --- a/zipline/utils/numpy_utils.py +++ b/zipline/utils/numpy_utils.py @@ -148,6 +148,75 @@ def repeat_last_axis(array, count): return as_strided(array, array.shape + (count,), array.strides + (0,)) +def rolling_window(array, length): + """ + Restride an array of shape + + (X_0, ... X_N) + + into an array of shape + + (length, X_0 - length + 1, ... X_N) + + where each slice at index i along the first axis is equivalent to + + result[i] = array[length * i:length * (i + 1)] + + Parameters + ---------- + array : np.ndarray + The base array. + length : int + Length of the synthetic first axis to generate. + + Returns + ------- + out : np.ndarray + + Example + ------- + >>> from numpy import arange + >>> a = arange(25).reshape(5, 5) + >>> a + array([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + + >>> rolling_window(a, 2) + array([[[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9]], + + [[ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14]], + + [[10, 11, 12, 13, 14], + [15, 16, 17, 18, 19]], + + [[15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]]) + """ + orig_shape = array.shape + if not orig_shape: + raise IndexError("Can't restride a scalar.") + elif orig_shape[0] <= length: + raise IndexError( + "Can't restride array of shape {shape} with" + " a window length of {len}".format( + shape=orig_shape, + len=length, + ) + ) + + num_windows = (orig_shape[0] - length + 1) + new_shape = (num_windows, length) + orig_shape[1:] + + new_strides = (array.strides[0],) + array.strides + + return as_strided(array, new_shape, new_strides) + + # Sentinel value that isn't NaT. _notNaT = make_datetime64D(0)