diff --git a/skimage/segmentation/random_walker_segmentation.py b/skimage/segmentation/random_walker_segmentation.py index 6df714d2..3e1c755f 100644 --- a/skimage/segmentation/random_walker_segmentation.py +++ b/skimage/segmentation/random_walker_segmentation.py @@ -334,17 +334,18 @@ def random_walker(data, labels, beta=130, mode='bf', tol=1.e-3, copy=True, # Parse input data if not multichannel: # We work with 4-D arrays of floats + assert data.ndim > 1 and data.ndim < 4, 'For non-multichannel input, \ + data must be of dimension 2 \ + or 3.' dims = data.shape - data = np.atleast_3d(img_as_float(data)) - data.shape += (1,) + data = np.atleast_3d(img_as_float(data.copy()))[..., np.newaxis] else: dims = data[..., 0].shape assert multichannel and data.ndim > 2, 'For multichannel input, data \ must have >= 3 dimensions.' - data = img_as_float(data) + data = img_as_float(data.copy()) if data.ndim == 3: - data.shape += (1,) - data = data.transpose((0, 1, 3, 2)) + data = data[..., np.newaxis].transpose((0, 1, 3, 2)) if copy: labels = np.copy(labels)