mirror of
https://github.com/wassname/scikit-image.git
synced 2026-06-29 05:17:50 +08:00
6635cf16db
* when returning the full probability instead of the segmentation
* for three or more labels (two was OK)
175 lines
5.9 KiB
Python
175 lines
5.9 KiB
Python
import numpy as np
|
|
from skimage.segmentation import random_walker
|
|
try:
|
|
import pyamg
|
|
amg_loaded = True
|
|
except ImportError:
|
|
amg_loaded = False
|
|
|
|
|
|
def make_2d_syntheticdata(lx, ly=None):
|
|
if ly is None:
|
|
ly = lx
|
|
data = np.zeros((lx, ly)) + 0.1 * np.random.randn(lx, ly)
|
|
small_l = int(lx / 5)
|
|
data[lx / 2 - small_l:lx / 2 + small_l,
|
|
ly / 2 - small_l:ly / 2 + small_l] = 1
|
|
data[lx / 2 - small_l + 1:lx / 2 + small_l - 1,
|
|
ly / 2 - small_l + 1:ly / 2 + small_l - 1] = \
|
|
0.1 * np.random.randn(2 * small_l - 2, 2 * small_l - 2)
|
|
data[lx / 2 - small_l, ly / 2 - small_l / 8:ly / 2 + small_l / 8] = 0
|
|
seeds = np.zeros_like(data)
|
|
seeds[lx / 5, ly / 5] = 1
|
|
seeds[lx / 2 + small_l / 4, ly / 2 - small_l / 4] = 2
|
|
return data, seeds
|
|
|
|
|
|
def make_3d_syntheticdata(lx, ly=None, lz=None):
|
|
if ly is None:
|
|
ly = lx
|
|
if lz is None:
|
|
lz = lx
|
|
data = np.zeros((lx, ly, lz)) + 0.1 * np.random.randn(lx, ly, lz)
|
|
small_l = int(lx / 5)
|
|
data[lx / 2 - small_l:lx / 2 + small_l,
|
|
ly / 2 - small_l:ly / 2 + small_l,
|
|
lz / 2 - small_l:lz / 2 + small_l] = 1
|
|
data[lx / 2 - small_l + 1:lx / 2 + small_l - 1,
|
|
ly / 2 - small_l + 1:ly / 2 + small_l - 1,
|
|
lz / 2 - small_l + 1:lz / 2 + small_l - 1] = 0
|
|
# make a hole
|
|
hole_size = np.max([1, small_l / 8])
|
|
data[lx / 2 - small_l,
|
|
ly / 2 - hole_size:ly / 2 + hole_size,\
|
|
lz / 2 - hole_size:lz / 2 + hole_size] = 0
|
|
seeds = np.zeros_like(data)
|
|
seeds[lx / 5, ly / 5, lz / 5] = 1
|
|
seeds[lx / 2 + small_l / 4, ly / 2 - small_l / 4, lz / 2 - small_l / 4] = 2
|
|
return data, seeds
|
|
|
|
|
|
def test_2d_bf():
|
|
lx = 70
|
|
ly = 100
|
|
data, labels = make_2d_syntheticdata(lx, ly)
|
|
labels_bf = random_walker(data, labels, beta=90, mode='bf')
|
|
assert (labels_bf[25:45, 40:60] == 2).all()
|
|
full_prob_bf = random_walker(data, labels, beta=90, mode='bf',
|
|
return_full_prob=True)
|
|
assert (full_prob_bf[1, 25:45, 40:60] >=
|
|
full_prob_bf[0, 25:45, 40:60]).all()
|
|
# Now test with more than two labels
|
|
labels[55, 80] = 3
|
|
full_prob_bf = random_walker(data, labels, beta=90, mode='bf',
|
|
return_full_prob=True)
|
|
assert (full_prob_bf[1, 25:45, 40:60] >=
|
|
full_prob_bf[0, 25:45, 40:60]).all()
|
|
assert len(full_prob_bf) == 3
|
|
|
|
def test_2d_cg():
|
|
lx = 70
|
|
ly = 100
|
|
data, labels = make_2d_syntheticdata(lx, ly)
|
|
labels_cg = random_walker(data, labels, beta=90, mode='cg')
|
|
assert (labels_cg[25:45, 40:60] == 2).all()
|
|
full_prob = random_walker(data, labels, beta=90, mode='cg',
|
|
return_full_prob=True)
|
|
assert (full_prob[1, 25:45, 40:60] >=
|
|
full_prob[0, 25:45, 40:60]).all()
|
|
return data, labels_cg
|
|
|
|
|
|
def test_2d_cg_mg():
|
|
lx = 70
|
|
ly = 100
|
|
data, labels = make_2d_syntheticdata(lx, ly)
|
|
labels_cg_mg = random_walker(data, labels, beta=90, mode='cg_mg')
|
|
assert (labels_cg_mg[25:45, 40:60] == 2).all()
|
|
full_prob = random_walker(data, labels, beta=90, mode='cg_mg',
|
|
return_full_prob=True)
|
|
assert (full_prob[1, 25:45, 40:60] >=
|
|
full_prob[0, 25:45, 40:60]).all()
|
|
return data, labels_cg_mg
|
|
|
|
|
|
def test_types():
|
|
lx = 70
|
|
ly = 100
|
|
data, labels = make_2d_syntheticdata(lx, ly)
|
|
data = 255 * (data - data.min()) / (data.max() - data.min())
|
|
data = data.astype(np.uint8)
|
|
labels_cg_mg = random_walker(data, labels, beta=90, mode='cg_mg')
|
|
assert (labels_cg_mg[25:45, 40:60] == 2).all()
|
|
return data, labels_cg_mg
|
|
|
|
|
|
def test_reorder_labels():
|
|
lx = 70
|
|
ly = 100
|
|
data, labels = make_2d_syntheticdata(lx, ly)
|
|
labels[labels == 2] = 4
|
|
labels_bf = random_walker(data, labels, beta=90, mode='bf')
|
|
assert (labels_bf[25:45, 40:60] == 2).all()
|
|
return data, labels_bf
|
|
|
|
|
|
def test_2d_inactive():
|
|
lx = 70
|
|
ly = 100
|
|
data, labels = make_2d_syntheticdata(lx, ly)
|
|
labels[10:20, 10:20] = -1
|
|
labels[46:50, 33:38] = -2
|
|
labels = random_walker(data, labels, beta=90)
|
|
assert (labels.reshape((lx, ly))[25:45, 40:60] == 2).all()
|
|
return data, labels
|
|
|
|
|
|
def test_3d():
|
|
n = 30
|
|
lx, ly, lz = n, n, n
|
|
data, labels = make_3d_syntheticdata(lx, ly, lz)
|
|
labels = random_walker(data, labels, mode='cg')
|
|
assert (labels.reshape(data.shape)[13:17, 13:17, 13:17] == 2).all()
|
|
return data, labels
|
|
|
|
|
|
def test_3d_inactive():
|
|
n = 30
|
|
lx, ly, lz = n, n, n
|
|
data, labels = make_3d_syntheticdata(lx, ly, lz)
|
|
old_labels = np.copy(labels)
|
|
labels[5:25, 26:29, 26:29] = -1
|
|
after_labels = np.copy(labels)
|
|
labels = random_walker(data, labels, mode='cg')
|
|
assert (labels.reshape(data.shape)[13:17, 13:17, 13:17] == 2).all()
|
|
return data, labels, old_labels, after_labels
|
|
|
|
|
|
def test_multispectral_2d():
|
|
lx, ly = 70, 100
|
|
data, labels = make_2d_syntheticdata(lx, ly)
|
|
data2 = data.copy()
|
|
data.shape += (1,)
|
|
data = data.repeat(2, axis=2) # Result should be identical
|
|
multi_labels = random_walker(data, labels, mode='cg', multichannel=True)
|
|
single_labels = random_walker(data2, labels, mode='cg')
|
|
assert (multi_labels.reshape(labels.shape)[25:45, 40:60] == 2).all()
|
|
return data, multi_labels, single_labels, labels
|
|
|
|
|
|
def test_multispectral_3d():
|
|
n = 30
|
|
lx, ly, lz = n, n, n
|
|
data, labels = make_3d_syntheticdata(lx, ly, lz)
|
|
data.shape += (1,)
|
|
data = data.repeat(2, axis=3) # Result should be identical
|
|
multi_labels = random_walker(data, labels, mode='cg', multichannel=True)
|
|
single_labels = random_walker(data[..., 0], labels, mode='cg')
|
|
assert (multi_labels.reshape(labels.shape)[13:17, 13:17, 13:17] == 2).all()
|
|
assert (single_labels.reshape(labels.shape)[13:17, 13:17, 13:17] == 2).all()
|
|
return data, multi_labels, single_labels, labels
|
|
|
|
if __name__ == '__main__':
|
|
from numpy import testing
|
|
testing.run_module_suite()
|