mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-05 17:53:17 +08:00
Merge pull request #1522 from quantopian/adjusted-array-perspective-offset
MAINT: Perspective offset for load adjustments.
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
"""
|
||||
Tests for chunked adjustments.
|
||||
"""
|
||||
from itertools import chain
|
||||
from collections import namedtuple
|
||||
from itertools import chain, product
|
||||
from textwrap import dedent
|
||||
from unittest import TestCase
|
||||
|
||||
@@ -95,6 +96,20 @@ bytes_dtype = dtype('S3')
|
||||
unicode_dtype = dtype('U3')
|
||||
|
||||
|
||||
AdjustmentCase = namedtuple(
|
||||
'AdjustmentCase',
|
||||
[
|
||||
'name',
|
||||
'baseline',
|
||||
'window_length',
|
||||
'adjustments',
|
||||
'missing_value',
|
||||
'perspective_offset',
|
||||
'expected_result',
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _gen_unadjusted_cases(name,
|
||||
make_input,
|
||||
make_expected_output,
|
||||
@@ -112,13 +127,14 @@ def _gen_unadjusted_cases(name,
|
||||
windowlen, nrows
|
||||
)
|
||||
|
||||
yield (
|
||||
"%s_length_%d" % (name, windowlen),
|
||||
input_array,
|
||||
windowlen,
|
||||
{},
|
||||
missing_value,
|
||||
[
|
||||
yield AdjustmentCase(
|
||||
name="%s_length_%d" % (name, windowlen),
|
||||
baseline=input_array,
|
||||
window_length=windowlen,
|
||||
adjustments={},
|
||||
missing_value=missing_value,
|
||||
perspective_offset=0,
|
||||
expected_result=[
|
||||
expected_output_array[offset:offset + windowlen]
|
||||
for offset in range(num_legal_windows)
|
||||
],
|
||||
@@ -199,6 +215,7 @@ def _gen_multiplicative_adjustment_cases(dtype):
|
||||
adjustments,
|
||||
buffer_as_of,
|
||||
nrows,
|
||||
perspective_offsets=(0, 1),
|
||||
)
|
||||
|
||||
|
||||
@@ -301,6 +318,7 @@ def _gen_overwrite_adjustment_cases(dtype):
|
||||
adjustments,
|
||||
buffer_as_of,
|
||||
nrows=6,
|
||||
perspective_offsets=(0, 1),
|
||||
)
|
||||
|
||||
|
||||
@@ -397,13 +415,13 @@ def _gen_overwrite_1d_array_adjustment_case(dtype):
|
||||
[2, 4, 5],
|
||||
[2, 5, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
return _gen_expectations(
|
||||
baseline,
|
||||
missing_value,
|
||||
adjustments,
|
||||
buffer_as_of,
|
||||
nrows=6,
|
||||
perspective_offsets=(0, 1),
|
||||
)
|
||||
|
||||
|
||||
@@ -411,30 +429,63 @@ def _gen_expectations(baseline,
|
||||
missing_value,
|
||||
adjustments,
|
||||
buffer_as_of,
|
||||
nrows):
|
||||
|
||||
for windowlen in valid_window_lengths(nrows):
|
||||
nrows,
|
||||
perspective_offsets):
|
||||
|
||||
for windowlen, perspective_offset in product(valid_window_lengths(nrows),
|
||||
perspective_offsets):
|
||||
# How long is an iterator of length-N windows on this buffer?
|
||||
# For example, for a window of length 3 on a buffer of length 6, there
|
||||
# are four valid windows.
|
||||
num_legal_windows = num_windows_of_length_M_on_buffers_of_length_N(
|
||||
windowlen, nrows
|
||||
)
|
||||
|
||||
yield (
|
||||
"dtype_%s_length_%d" % (baseline.dtype, windowlen),
|
||||
baseline,
|
||||
# Build the sequence of regions in the underlying buffer we expect to
|
||||
# see. For example, with a window length of 3 on a buffer of length 6,
|
||||
# we expect to see:
|
||||
# (buffer[0:3], buffer[1:4], buffer[2:5], buffer[3:6])
|
||||
#
|
||||
slices = [slice(i, i + windowlen) for i in range(num_legal_windows)]
|
||||
|
||||
# The sequence of perspectives we expect to take on the underlying
|
||||
# data. For example, with a window length of 3 and a perspective offset
|
||||
# of 1, we expect to see:
|
||||
# (buffer_as_of[3], buffer_as_of[4], buffer_as_of[5], buffer_as_of[5])
|
||||
#
|
||||
initial_perspective = windowlen + perspective_offset - 1
|
||||
perspectives = range(
|
||||
initial_perspective,
|
||||
initial_perspective + num_legal_windows
|
||||
)
|
||||
|
||||
def as_of(p):
|
||||
# perspective_offset can push us past the end of the underlying
|
||||
# buffer/adjustments. When it does, we should always see the latest
|
||||
# version of the buffer.
|
||||
if p >= len(buffer_as_of):
|
||||
return buffer_as_of[-1]
|
||||
return buffer_as_of[p]
|
||||
|
||||
expected_iterator_results = [
|
||||
as_of(perspective)[slice_]
|
||||
for slice_, perspective in zip(slices, perspectives)
|
||||
]
|
||||
|
||||
test_name = "dtype_{}_length_{}_perpective_offset_{}".format(
|
||||
baseline.dtype,
|
||||
windowlen,
|
||||
adjustments,
|
||||
missing_value,
|
||||
[
|
||||
# This is a nasty expression...
|
||||
#
|
||||
# Reading from right to left: we want a slice of length
|
||||
# 'windowlen', starting at 'offset', from the buffer on which
|
||||
# we've applied all adjustments corresponding to the last row
|
||||
# of the data, which will be (offset + windowlen - 1).
|
||||
buffer_as_of[offset + windowlen - 1][offset:offset + windowlen]
|
||||
for offset in range(num_legal_windows)
|
||||
],
|
||||
perspective_offset,
|
||||
)
|
||||
|
||||
yield AdjustmentCase(
|
||||
name=test_name,
|
||||
baseline=baseline,
|
||||
window_length=windowlen,
|
||||
adjustments=adjustments,
|
||||
missing_value=missing_value,
|
||||
perspective_offset=perspective_offset,
|
||||
expected_result=expected_iterator_results
|
||||
)
|
||||
|
||||
|
||||
@@ -504,6 +555,7 @@ class AdjustedArrayTestCase(TestCase):
|
||||
lookback,
|
||||
adjustments,
|
||||
missing_value,
|
||||
perspective_offset,
|
||||
expected_output):
|
||||
|
||||
array = AdjustedArray(data, NOMASK, adjustments, missing_value)
|
||||
@@ -519,11 +571,15 @@ class AdjustedArrayTestCase(TestCase):
|
||||
lookback,
|
||||
adjustments,
|
||||
missing_value,
|
||||
perspective_offset,
|
||||
expected):
|
||||
|
||||
array = AdjustedArray(data, NOMASK, adjustments, missing_value)
|
||||
for _ in range(2): # Iterate 2x ensure adjusted_arrays are re-usable.
|
||||
window_iter = array.traverse(lookback)
|
||||
window_iter = array.traverse(
|
||||
lookback,
|
||||
perspective_offset=perspective_offset,
|
||||
)
|
||||
for yielded, expected_yield in zip_longest(window_iter, expected):
|
||||
check_arrays(yielded, expected_yield)
|
||||
|
||||
@@ -584,14 +640,19 @@ class AdjustedArrayTestCase(TestCase):
|
||||
)
|
||||
def test_overwrite_adjustment_cases(self,
|
||||
name,
|
||||
data,
|
||||
baseline,
|
||||
lookback,
|
||||
adjustments,
|
||||
missing_value,
|
||||
perspective_offset,
|
||||
expected):
|
||||
array = AdjustedArray(data, NOMASK, adjustments, missing_value)
|
||||
array = AdjustedArray(baseline, NOMASK, adjustments, missing_value)
|
||||
|
||||
for _ in range(2): # Iterate 2x ensure adjusted_arrays are re-usable.
|
||||
window_iter = array.traverse(lookback)
|
||||
window_iter = array.traverse(
|
||||
lookback,
|
||||
perspective_offset=perspective_offset,
|
||||
)
|
||||
for yielded, expected_yield in zip_longest(window_iter, expected):
|
||||
check_arrays(yielded, expected_yield)
|
||||
|
||||
|
||||
@@ -107,8 +107,7 @@ class HistoryLoader(with_metaclass(ABCMeta)):
|
||||
def _array(self, start, end, assets, field):
|
||||
pass
|
||||
|
||||
def _get_adjustments_in_range(self, asset, dts, field,
|
||||
is_perspective_after):
|
||||
def _get_adjustments_in_range(self, asset, dts, field):
|
||||
"""
|
||||
Get the Float64Multiply objects to pass to an AdjustedArrayWindow.
|
||||
|
||||
@@ -154,11 +153,6 @@ class HistoryLoader(with_metaclass(ABCMeta)):
|
||||
if start < dt <= end:
|
||||
end_loc = dts.searchsorted(dt)
|
||||
adj_loc = end_loc
|
||||
if is_perspective_after:
|
||||
# Set adjustment pop location so that it applies
|
||||
# to last value if adjustment occurs immediately after
|
||||
# the last slot.
|
||||
adj_loc -= 1
|
||||
mult = Float64Multiply(0,
|
||||
end_loc - 1,
|
||||
0,
|
||||
@@ -175,11 +169,6 @@ class HistoryLoader(with_metaclass(ABCMeta)):
|
||||
if start < dt <= end:
|
||||
end_loc = dts.searchsorted(dt)
|
||||
adj_loc = end_loc
|
||||
if is_perspective_after:
|
||||
# Set adjustment pop location so that it applies
|
||||
# to last value if adjustment occurs immediately after
|
||||
# the last slot.
|
||||
adj_loc -= 1
|
||||
mult = Float64Multiply(0,
|
||||
end_loc - 1,
|
||||
0,
|
||||
@@ -200,11 +189,6 @@ class HistoryLoader(with_metaclass(ABCMeta)):
|
||||
ratio = s[1]
|
||||
end_loc = dts.searchsorted(dt)
|
||||
adj_loc = end_loc
|
||||
if is_perspective_after:
|
||||
# Set adjustment pop location so that it applies
|
||||
# to last value if adjustment occurs immediately after
|
||||
# the last slot.
|
||||
adj_loc -= 1
|
||||
mult = Float64Multiply(0,
|
||||
end_loc - 1,
|
||||
0,
|
||||
@@ -284,7 +268,7 @@ class HistoryLoader(with_metaclass(ABCMeta)):
|
||||
for i, asset in enumerate(needed_assets):
|
||||
if self._adjustments_reader:
|
||||
adjs = self._get_adjustments_in_range(
|
||||
asset, prefetch_dts, field, is_perspective_after)
|
||||
asset, prefetch_dts, field)
|
||||
else:
|
||||
adjs = {}
|
||||
window = window_type(
|
||||
@@ -292,7 +276,8 @@ class HistoryLoader(with_metaclass(ABCMeta)):
|
||||
view_kwargs,
|
||||
adjs,
|
||||
offset,
|
||||
size
|
||||
size,
|
||||
int(is_perspective_after)
|
||||
)
|
||||
sliding_window = SlidingWindow(window, size, start_ix, offset)
|
||||
asset_windows[asset] = sliding_window
|
||||
|
||||
@@ -35,6 +35,7 @@ cdef class AdjustedArrayWindow:
|
||||
readonly dict view_kwargs
|
||||
readonly Py_ssize_t window_length
|
||||
Py_ssize_t anchor, next_anchor, max_anchor, next_adj
|
||||
Py_ssize_t perspective_offset
|
||||
dict adjustments
|
||||
list adjustment_indices
|
||||
ndarray last_out
|
||||
@@ -44,14 +45,24 @@ cdef class AdjustedArrayWindow:
|
||||
dict view_kwargs not None,
|
||||
dict adjustments not None,
|
||||
Py_ssize_t offset,
|
||||
Py_ssize_t window_length):
|
||||
|
||||
Py_ssize_t window_length,
|
||||
Py_ssize_t perspective_offset):
|
||||
self.data = data
|
||||
self.view_kwargs = view_kwargs
|
||||
self.adjustments = adjustments
|
||||
self.adjustment_indices = sorted(adjustments, reverse=True)
|
||||
self.window_length = window_length
|
||||
self.anchor = window_length + offset
|
||||
if perspective_offset > 1:
|
||||
# Limit perspective_offset to 1.
|
||||
# To support an offset greater than 1, work must be done to
|
||||
# ensure that adjustments are retrieved for the datetimes between
|
||||
# the end of the window and the vantage point defined by the
|
||||
# perspective offset.
|
||||
raise Exception("perspective_offset should not exceed 1, value "
|
||||
"is perspective_offset={0}".format(
|
||||
perspective_offset))
|
||||
self.perspective_offset = perspective_offset
|
||||
self.next_anchor = self.anchor
|
||||
self.max_anchor = data.shape[0]
|
||||
|
||||
@@ -65,7 +76,7 @@ cdef class AdjustedArrayWindow:
|
||||
if len(self.adjustment_indices) > 0:
|
||||
return self.adjustment_indices.pop()
|
||||
else:
|
||||
return self.max_anchor
|
||||
return self.max_anchor + self.perspective_offset
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
@@ -84,7 +95,7 @@ cdef class AdjustedArrayWindow:
|
||||
# Apply any adjustments that occured before our current anchor.
|
||||
# Equivalently, apply any adjustments known **on or before** the date
|
||||
# for which we're calculating a window.
|
||||
while self.next_adj < anchor:
|
||||
while self.next_adj < anchor + self.perspective_offset:
|
||||
|
||||
for adjustment in self.adjustments[self.next_adj]:
|
||||
adjustment.mutate(self.data)
|
||||
|
||||
@@ -200,7 +200,10 @@ class AdjustedArray(object):
|
||||
return LabelWindow
|
||||
return CONCRETE_WINDOW_TYPES[self._data.dtype]
|
||||
|
||||
def traverse(self, window_length, offset=0):
|
||||
def traverse(self,
|
||||
window_length,
|
||||
offset=0,
|
||||
perspective_offset=0):
|
||||
"""
|
||||
Produce an iterator rolling windows rows over our data.
|
||||
Each emitted window will have `window_length` rows.
|
||||
@@ -210,7 +213,10 @@ class AdjustedArray(object):
|
||||
window_length : int
|
||||
The number of rows in each emitted window.
|
||||
offset : int, optional
|
||||
Number of rows to skip before the first window.
|
||||
Number of rows to skip before the first window. Default is 0.
|
||||
perspective_offset : int, optional
|
||||
Number of rows past the end of the current window from which to
|
||||
"view" the underlying data.
|
||||
"""
|
||||
data = self._data.copy()
|
||||
_check_window_params(data, window_length)
|
||||
@@ -220,6 +226,7 @@ class AdjustedArray(object):
|
||||
self.adjustments,
|
||||
offset,
|
||||
window_length,
|
||||
perspective_offset,
|
||||
)
|
||||
|
||||
def inspect(self):
|
||||
|
||||
Reference in New Issue
Block a user