mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 22:01:30 +08:00
PERF: Refactor AdjustedArrayWindow.
Make `__next__` and `seek` share code instead of having `seek` call `__next__` once per tick. This avoids a large number of integer comparisons and `asanyarray` calls when seeking more than one tick forward.
This commit is contained in:
@@ -15,6 +15,10 @@ from numpy cimport ndarray
|
||||
from numpy import asanyarray
|
||||
|
||||
|
||||
class Exhausted(Exception):
|
||||
pass
|
||||
|
||||
|
||||
cdef class AdjustedArrayWindow:
|
||||
"""
|
||||
An iterator representing a moving view over an AdjustedArray.
|
||||
@@ -34,11 +38,11 @@ cdef class AdjustedArrayWindow:
|
||||
readonly databuffer data
|
||||
readonly dict view_kwargs
|
||||
readonly Py_ssize_t window_length
|
||||
Py_ssize_t anchor, next_anchor, max_anchor, next_adj
|
||||
Py_ssize_t anchor, max_anchor, next_adj
|
||||
Py_ssize_t perspective_offset
|
||||
dict adjustments
|
||||
list adjustment_indices
|
||||
ndarray last_out
|
||||
ndarray output
|
||||
|
||||
def __cinit__(self,
|
||||
databuffer data not None,
|
||||
@@ -52,7 +56,7 @@ cdef class AdjustedArrayWindow:
|
||||
self.adjustments = adjustments
|
||||
self.adjustment_indices = sorted(adjustments, reverse=True)
|
||||
self.window_length = window_length
|
||||
self.anchor = window_length + offset
|
||||
self.anchor = window_length + offset - 1
|
||||
if perspective_offset > 1:
|
||||
# Limit perspective_offset to 1.
|
||||
# To support an offset greater than 1, work must be done to
|
||||
@@ -63,11 +67,10 @@ cdef class AdjustedArrayWindow:
|
||||
"is perspective_offset={0}".format(
|
||||
perspective_offset))
|
||||
self.perspective_offset = perspective_offset
|
||||
self.next_anchor = self.anchor
|
||||
self.max_anchor = data.shape[0]
|
||||
|
||||
self.next_adj = self.pop_next_adj()
|
||||
self.last_out = None
|
||||
self.output = None
|
||||
|
||||
cdef pop_next_adj(self):
|
||||
"""
|
||||
@@ -82,54 +85,61 @@ cdef class AdjustedArrayWindow:
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
self._tick_forward(1)
|
||||
except Exhausted:
|
||||
raise StopIteration()
|
||||
|
||||
self._update_output()
|
||||
return self.output
|
||||
|
||||
def seek(self, Py_ssize_t target_anchor):
|
||||
cdef:
|
||||
Py_ssize_t anchor = self.anchor
|
||||
|
||||
if target_anchor < anchor:
|
||||
raise Exception('Can not access data after window has passed.')
|
||||
|
||||
if target_anchor == anchor:
|
||||
return self.output
|
||||
|
||||
self._tick_forward(target_anchor - anchor)
|
||||
self._update_output()
|
||||
|
||||
return self.output
|
||||
|
||||
cdef inline _tick_forward(self, int N):
|
||||
cdef:
|
||||
object adjustment
|
||||
ndarray out
|
||||
Py_ssize_t start, anchor
|
||||
dict view_kwargs
|
||||
Py_ssize_t anchor = self.anchor
|
||||
Py_ssize_t target = anchor + N
|
||||
|
||||
anchor = self.anchor = self.next_anchor
|
||||
if anchor > self.max_anchor:
|
||||
raise StopIteration()
|
||||
if target > self.max_anchor:
|
||||
raise Exhausted()
|
||||
|
||||
# Apply any adjustments that occurred before our current anchor.
|
||||
# Equivalently, apply any adjustments known **on or before** the date
|
||||
# for which we're calculating a window.
|
||||
while self.next_adj < anchor + self.perspective_offset:
|
||||
while self.next_adj < target + self.perspective_offset:
|
||||
|
||||
for adjustment in self.adjustments[self.next_adj]:
|
||||
adjustment.mutate(self.data)
|
||||
|
||||
self.next_adj = self.pop_next_adj()
|
||||
|
||||
start = anchor - self.window_length
|
||||
self.anchor = target
|
||||
|
||||
# If our data is a custom subclass of ndarray, preserve that subclass
|
||||
# by using asanyarray instead of asarray.
|
||||
out = asanyarray(self.data[start:self.anchor])
|
||||
view_kwargs = self.view_kwargs
|
||||
cdef inline _update_output(self):
|
||||
cdef:
|
||||
ndarray new_out
|
||||
Py_ssize_t anchor = self.anchor
|
||||
dict view_kwargs = self.view_kwargs
|
||||
|
||||
new_out = asanyarray(self.data[anchor - self.window_length:anchor])
|
||||
if view_kwargs:
|
||||
out = out.view(**view_kwargs)
|
||||
out.setflags(write=False)
|
||||
|
||||
self.next_anchor = self.anchor + 1
|
||||
self.last_out = out
|
||||
return out
|
||||
|
||||
def seek(self, target_anchor):
|
||||
cdef ndarray out = None
|
||||
|
||||
if target_anchor < self.anchor:
|
||||
raise Exception('Can not access data after window has passed.')
|
||||
|
||||
if target_anchor == self.anchor:
|
||||
return self.last_out
|
||||
|
||||
while self.anchor < target_anchor:
|
||||
out = next(self)
|
||||
|
||||
self.last_out = out
|
||||
return out
|
||||
new_out = new_out.view(**view_kwargs)
|
||||
new_out.setflags(write=False)
|
||||
self.output = new_out
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: window_length=%d, anchor=%d, max_anchor=%d, dtype=%r>" % (
|
||||
|
||||
Reference in New Issue
Block a user