diff --git a/skimage/graph/graph_cut.py b/skimage/graph/graph_cut.py index 3f973e22..d7710f0b 100644 --- a/skimage/graph/graph_cut.py +++ b/skimage/graph/graph_cut.py @@ -72,7 +72,8 @@ def cut_threshold(labels, rag, thresh, in_place=True): return map_array[labels] -def cut_normalized(labels, rag, thresh=0.001, num_cuts=10, in_place=True): +def cut_normalized(labels, rag, thresh=0.001, num_cuts=10, in_place=True, + max_edge=1.0): """Perform Normalized Graph cut on the Region Adjacency Graph. Given an image's labels and its similarity RAG, recursively perform @@ -94,6 +95,10 @@ def cut_normalized(labels, rag, thresh=0.001, num_cuts=10, in_place=True): in_place : bool If set, modifies `rag` in place. For each node `n` the function will set a new attribute ``rag.node[n]['ncut label']``. + max_edge : float, optional + The maximum possible value of an edge in the RAG. This corresponds to + an edge between identical regions. This is used to put self + edges in the RAG. Returns ------- @@ -118,6 +123,9 @@ def cut_normalized(labels, rag, thresh=0.001, num_cuts=10, in_place=True): if not in_place: rag = rag.copy() + for node in rag.nodes_iter(): + rag.add_edge(node, node, weight=max_edge) + _ncut_relabel(rag, thresh, num_cuts) map_array = np.zeros(labels.max() + 1) diff --git a/skimage/graph/rag.py b/skimage/graph/rag.py index 5ef9a83b..df9b4ebf 100644 --- a/skimage/graph/rag.py +++ b/skimage/graph/rag.py @@ -115,7 +115,8 @@ def _add_edge_filter(values, graph): values = values.astype(int) current = values[0] for value in values[1:]: - graph.add_edge(current, value) + if value != current: + graph.add_edge(current, value) return 0 diff --git a/skimage/graph/tests/test_rag.py b/skimage/graph/tests/test_rag.py index 0e62aa1a..26f5a85c 100644 --- a/skimage/graph/tests/test_rag.py +++ b/skimage/graph/tests/test_rag.py @@ -104,5 +104,7 @@ def test_cut_normalized(): def test_rag_error(): img = np.zeros((10, 10, 3), dtype='uint8') labels = np.zeros((10, 10), dtype='uint8') + labels[:5, :] = 0 + labels[5:, :] = 1 testing.assert_raises(ValueError, graph.rag_mean_color, img, labels, 2, 'non existant mode')