mirror of
https://github.com/wassname/scikit-image.git
synced 2026-07-05 05:05:33 +08:00
fix to preserve input image shape upon executing random_walker
This commit is contained in:
@@ -334,17 +334,18 @@ def random_walker(data, labels, beta=130, mode='bf', tol=1.e-3, copy=True,
|
||||
# Parse input data
|
||||
if not multichannel:
|
||||
# We work with 4-D arrays of floats
|
||||
assert data.ndim > 1 and data.ndim < 4, 'For non-multichannel input, \
|
||||
data must be of dimension 2 \
|
||||
or 3.'
|
||||
dims = data.shape
|
||||
data = np.atleast_3d(img_as_float(data))
|
||||
data.shape += (1,)
|
||||
data = np.atleast_3d(img_as_float(data.copy()))[..., np.newaxis]
|
||||
else:
|
||||
dims = data[..., 0].shape
|
||||
assert multichannel and data.ndim > 2, 'For multichannel input, data \
|
||||
must have >= 3 dimensions.'
|
||||
data = img_as_float(data)
|
||||
data = img_as_float(data.copy())
|
||||
if data.ndim == 3:
|
||||
data.shape += (1,)
|
||||
data = data.transpose((0, 1, 3, 2))
|
||||
data = data[..., np.newaxis].transpose((0, 1, 3, 2))
|
||||
|
||||
if copy:
|
||||
labels = np.copy(labels)
|
||||
|
||||
Reference in New Issue
Block a user