mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-03 17:13:37 +08:00
ENH: add an adjustment for float64 2d arrays.
BUG: fix syntax error MAINT: optimize code for cython
This commit is contained in:
@@ -3,7 +3,7 @@ from cpython cimport Py_EQ
|
||||
|
||||
from pandas import isnull, Timestamp
|
||||
from numpy cimport float64_t, uint8_t, int64_t
|
||||
from numpy import datetime64, float64
|
||||
from numpy import asarray, datetime64, float64
|
||||
# Purely for readability. There aren't C-level declarations for these types.
|
||||
ctypedef object Int64Index_t
|
||||
ctypedef object DatetimeIndex_t
|
||||
@@ -364,6 +364,83 @@ cdef class Float64Overwrite(Float64Adjustment):
|
||||
data[row, col] = value
|
||||
|
||||
|
||||
cdef class Float641DArrayOverwrite:
|
||||
"""
|
||||
An adjustment that overwrites subarrays with a value for each subarray.
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
>>> import numpy as np
|
||||
>>> arr = np.arange(25, dtype=float).reshape(5, 5)
|
||||
>>> arr
|
||||
array([[ 0., 1., 2., 3., 4.],
|
||||
[ 5., 6., 7., 8., 9.],
|
||||
[ 10., 11., 12., 13., 14.],
|
||||
[ 15., 16., 17., 18., 19.],
|
||||
[ 20., 21., 22., 23., 24.]])
|
||||
>>> adj = Float641DArrayOverwrite(
|
||||
... row_starts=np.array([0, 3]),
|
||||
... row_ends=np.array([2, 4]),
|
||||
... column_starts=np.array([0, 2]),
|
||||
... column_ends=np.array([1, 4]),
|
||||
... values=np.array([10., 20.]),
|
||||
)
|
||||
>>> adj.mutate(arr)
|
||||
>>> arr
|
||||
array([[ 10., 10., 2., 3., 4.],
|
||||
[ 10., 10., 7., 8., 9.],
|
||||
[ 10., 10., 12., 13., 14.],
|
||||
[ 15., 16., 20., 20., 20.],
|
||||
[ 20., 21., 20., 20., 20.]])
|
||||
"""
|
||||
cdef:
|
||||
readonly int64_t[:] row_starts, row_ends, column_starts, column_ends
|
||||
readonly float64_t[:] values
|
||||
|
||||
def __init__(self,
|
||||
int64_t[:] row_starts,
|
||||
int64_t[:] row_ends,
|
||||
int64_t[:] column_starts,
|
||||
int64_t[:] column_ends,
|
||||
float64_t[:] values):
|
||||
assert (len(row_starts) ==
|
||||
len(row_ends) ==
|
||||
len(column_starts) ==
|
||||
len(column_ends))
|
||||
for (row_start, row_end) in zip(row_starts, row_ends):
|
||||
assert row_start <= row_end
|
||||
for (column_start, column_end) in zip(column_starts, column_ends):
|
||||
assert column_start <= column_end
|
||||
|
||||
self.row_starts = row_starts
|
||||
self.row_ends = row_ends
|
||||
self.column_starts = column_starts
|
||||
self.column_ends = column_ends
|
||||
self.values = values
|
||||
|
||||
cpdef mutate(self, float64_t[:, :] data):
|
||||
cdef Py_ssize_t fill_range, row, col
|
||||
for fill_range in range(len(self.row_starts)):
|
||||
for row in range(self.row_starts[fill_range],
|
||||
self.row_ends[fill_range] + 1):
|
||||
for col in range(self.column_starts[fill_range],
|
||||
self.column_ends[fill_range] + 1):
|
||||
data[row, col] = self.values[fill_range]
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"%s(row_starts=%s, row_ends=%s,"
|
||||
" column_starts=%s, column_ends=%s, values=%s)" % (
|
||||
type(self).__name__,
|
||||
asarray(self.row_starts),
|
||||
asarray(self.row_ends),
|
||||
asarray(self.column_starts),
|
||||
asarray(self.column_ends),
|
||||
asarray(self.values),
|
||||
)
|
||||
)
|
||||
|
||||
cdef class Float64Add(Float64Adjustment):
|
||||
"""
|
||||
An adjustment that adds a float.
|
||||
|
||||
Reference in New Issue
Block a user