Added and tested proper exception handling concerning array shape

This commit is contained in:
Matěj Týč
2014-11-16 15:33:13 +01:00
parent eb55d177c3
commit abf3dbc593
2 changed files with 14 additions and 20 deletions
+10 -20
View File
@@ -45,9 +45,7 @@ ctypedef struct bginfo:
DTYPE_t background_label
cdef bginfo get_bginfo(background_val):
cdef bginfo ret
cdef void get_bginfo(background_val, bginfo *ret) except *:
if background_val is None:
warnings.warn(DeprecationWarning(
'The default value for `background` will change to 0 in v0.12'
@@ -60,7 +58,6 @@ cdef bginfo get_bginfo(background_val):
# upon the first background pixel occurence
ret.background_node = -999
ret.background_label = -1
return ret
# A pixel has neighbors that have already been scanned.
@@ -110,13 +107,11 @@ cdef struct s_shpinfo:
fun_ravel ravel_index
cdef shape_info get_shape_info(inarr_shape):
cdef void get_shape_info(inarr_shape, shape_info *res) except *:
"""
Precalculates all the needed data from the input array shape
and stores them in the shape_info struct.
"""
cdef shape_info res
res.y = 1
res.z = 1
res.ravel_index = ravel_index2D
@@ -135,7 +130,9 @@ cdef shape_info get_shape_info(inarr_shape):
res.z = inarr_shape[0]
res.ravel_index = ravel_index3D
else:
assert "Only for images of dimension 1-3 (got %s)" % res.ndim
raise NotImplementedError(
"Only for images of dimension 1-3 are supported, got a %sD one"
% res.ndim)
res.numels = res.x * res.y * res.z
@@ -152,12 +149,12 @@ cdef shape_info get_shape_info(inarr_shape):
# res.DEX[D_ee] = 0
# So now the 2nd (needed for 2D and 3D) part, y = 0, z = 1
res.DEX[D_ea] = res.ravel_index(-1, -1, 0, & res)
res.DEX[D_ea] = res.ravel_index(-1, -1, 0, res)
res.DEX[D_eb] = res.DEX[D_ea] + 1
res.DEX[D_ec] = res.DEX[D_eb] + 1
# And now the 3rd (needed only for 3D) part, z = 0
res.DEX[D_ef] = res.ravel_index(-1, -1, -1, & res)
res.DEX[D_ef] = res.ravel_index(-1, -1, -1, res)
res.DEX[D_eg] = res.DEX[D_ef] + 1
res.DEX[D_eh] = res.DEX[D_ef] + 2
res.DEX[D_ei] = res.DEX[D_ef] - res.DEX[D_eb] # DEX[D_eb] = one row up, remember?
@@ -167,8 +164,6 @@ cdef shape_info get_shape_info(inarr_shape):
res.DEX[D_em] = res.DEX[D_el] + 1
res.DEX[D_en] = res.DEX[D_el] + 2
return res
cdef inline void join_trees_wrapper(DTYPE_t *data_p, DTYPE_t *forest_p,
DTYPE_t rindex, INTS_t idxdiff):
@@ -354,19 +349,14 @@ def label(input, DTYPE_t neighbors=8, background=None, return_num=False):
cdef shape_info shapeinfo
cdef bginfo bg
shapeinfo = get_shape_info(input.shape)
bg = get_bginfo(background)
get_shape_info(input.shape, &shapeinfo)
get_bginfo(background, &bg)
if neighbors != 4 and neighbors != 8:
raise ValueError('Neighbors must be either 4 or 8.')
scanBG(data_p, forest_p, & shapeinfo, & bg)
if shapeinfo.ndim == 1:
scan1D(data_p, forest_p, & shapeinfo, & bg, neighbors, 0, 0)
elif shapeinfo.ndim == 2:
scan2D(data_p, forest_p, & shapeinfo, & bg, neighbors, 0)
elif shapeinfo.ndim == 3:
scan3D(data_p, forest_p, & shapeinfo, & bg, neighbors)
scan3D(data_p, forest_p, & shapeinfo, & bg, neighbors)
# Label output
cdef DTYPE_t ctr
+4
View File
@@ -244,6 +244,10 @@ class TestConnectedComponents3d:
assert_array_equal(label(x, background=0, return_num=True)[1], 3)
def test_nd(self):
x = np.ones((1, 2, 3, 4))
np.testing.assert_raises(NotImplementedError, label, x)
if __name__ == "__main__":
run_module_suite()