fix to preserve input image shape upon executing random_walker

This commit is contained in:
Josh Warner (Mac)
2013-04-11 11:49:10 -05:00
parent edfce92f30
commit 1edbf3e6f4
@@ -334,17 +334,18 @@ def random_walker(data, labels, beta=130, mode='bf', tol=1.e-3, copy=True,
# Parse input data
if not multichannel:
# We work with 4-D arrays of floats
assert data.ndim > 1 and data.ndim < 4, 'For non-multichannel input, \
data must be of dimension 2 \
or 3.'
dims = data.shape
data = np.atleast_3d(img_as_float(data))
data.shape += (1,)
data = np.atleast_3d(img_as_float(data.copy()))[..., np.newaxis]
else:
dims = data[..., 0].shape
assert multichannel and data.ndim > 2, 'For multichannel input, data \
must have >= 3 dimensions.'
data = img_as_float(data)
data = img_as_float(data.copy())
if data.ndim == 3:
data.shape += (1,)
data = data.transpose((0, 1, 3, 2))
data = data[..., np.newaxis].transpose((0, 1, 3, 2))
if copy:
labels = np.copy(labels)