Files
scikit-image/skimage/graph/tests/test_rag.py
T
2014-08-18 18:42:23 +05:30

111 lines
3.1 KiB
Python

import numpy as np
from skimage import graph
from skimage._shared.version_requirements import is_installed
from numpy.testing.decorators import skipif
from skimage import segmentation
from numpy import testing
def max_edge(g, src, dst, n):
    """Weight function returning the larger of the ``n-src`` / ``n-dst`` weights.

    Passed as ``weight_func`` to ``RAG.merge_nodes`` so that, when merging
    creates conflicting edges, the heavier one is kept. An edge that does
    not exist is treated as having weight ``-inf`` so it never wins.
    """
    missing = {'weight': -np.inf}
    neighbours = g[n]
    return max(neighbours.get(node, missing)['weight'] for node in (src, dst))
@skipif(not is_installed('networkx'))
def test_rag_merge():
    """Merging RAG nodes resolves conflicting edge weights via ``weight_func``.

    With the default weight function the smaller of two conflicting weights
    is kept; with ``max_edge`` the larger one is kept. Merging every node
    into a single node must collect all labels and leave no edges behind.
    """
    g = graph.rag.RAG()
    for i in range(5):
        # Keyword attributes work with both the NetworkX 1.x ``attr_dict``
        # signature and NetworkX >= 2.0, which dropped the positional dict.
        g.add_node(i, labels=[i])
    g.add_edge(0, 1, weight=10)
    g.add_edge(1, 2, weight=20)
    g.add_edge(2, 3, weight=30)
    g.add_edge(3, 0, weight=40)
    g.add_edge(0, 2, weight=50)
    g.add_edge(3, 4, weight=60)
    gc = g.copy()

    # We merge nodes and ensure that the minimum weight is chosen
    # when there is a conflict.
    g.merge_nodes(0, 2)
    assert g[1][2]['weight'] == 10
    assert g[2][3]['weight'] == 30

    # We specify `max_edge` as `weight_func` to ensure that the maximum
    # weight is chosen in case of conflict.
    gc.merge_nodes(0, 2, weight_func=max_edge)
    assert gc[1][2]['weight'] == 20
    assert gc[2][3]['weight'] == 40

    # Collapse the remaining nodes into node 4.
    g.merge_nodes(1, 4)
    g.merge_nodes(2, 3)
    g.merge_nodes(3, 4)
    # ``nodes(data=True)`` yields (node, attrs) pairs on NetworkX 1.x and 2.x
    # alike, unlike ``g.node`` which was removed in later 2.x releases.
    assert sorted(dict(g.nodes(data=True))[4]['labels']) == list(range(5))
    # ``edges()`` is a list on NetworkX 1.x but an EdgeView on 2.x, so
    # comparing it to ``[]`` is not portable; count the edges instead.
    assert g.number_of_edges() == 0
@skipif(not is_installed('networkx'))
def test_threshold_cut():
    """``cut_threshold`` on a four-quadrant image should leave two labels.

    The two top quadrants differ by one intensity level, as do the two
    bottom ones, so a threshold of 10 is expected to merge each pair.
    """
    # (region, grey value) for each quadrant; the label is its list index.
    quadrants = [
        (np.s_[:50, :50], 255),
        (np.s_[:50, 50:], 254),
        (np.s_[50:, :50], 2),
        (np.s_[50:, 50:], 1),
    ]
    img = np.zeros((100, 100, 3), dtype='uint8')
    labels = np.zeros((100, 100), dtype='uint8')
    for label, (region, value) in enumerate(quadrants):
        img[region] = value
        labels[region] = label

    rag = graph.rag_mean_color(img, labels)

    # Run once with in_place=False, then with the default; both must give
    # two labels (max label 1).
    for kwargs in ({'in_place': False}, {}):
        new_labels = graph.cut_threshold(labels, rag, 10, **kwargs)
        assert new_labels.max() == 1
@skipif(not is_installed('networkx'))
def test_cut_normalized():
    """``cut_normalized`` on a four-quadrant image should leave two labels.

    Uses a similarity-mode RAG; the near-white top pair and the near-black
    bottom pair are expected to end up as one segment each.
    """
    # (region, grey value) for each quadrant; the label is its list index.
    quadrants = [
        (np.s_[:50, :50], 255),
        (np.s_[:50, 50:], 254),
        (np.s_[50:, :50], 2),
        (np.s_[50:, 50:], 1),
    ]
    img = np.zeros((100, 100, 3), dtype='uint8')
    labels = np.zeros((100, 100), dtype='uint8')
    for label, (region, value) in enumerate(quadrants):
        img[region] = value
        labels[region] = label

    rag = graph.rag_mean_color(img, labels, mode='similarity')

    # Run once with in_place=False, then with the default. The output is
    # compacted with relabel_sequential before checking for two labels.
    for kwargs in ({'in_place': False}, {}):
        new_labels = graph.cut_normalized(labels, rag, **kwargs)
        new_labels, _, _ = segmentation.relabel_sequential(new_labels)
        assert new_labels.max() == 1
@skipif(not is_installed('networkx'))
def test_rag_error():
    """An unrecognised ``mode`` passed to ``rag_mean_color`` raises ValueError."""
    img = np.zeros((10, 10, 3), dtype='uint8')
    # Top half labelled 0 (already zero), bottom half labelled 1.
    labels = np.zeros((10, 10), dtype='uint8')
    labels[5:] = 1
    bad_args = (img, labels, 2, 'non existant mode')
    testing.assert_raises(ValueError, graph.rag_mean_color, *bad_args)