From 4c4d7108e149fbdaa741d665ae152ea8297b576a Mon Sep 17 00:00:00 2001 From: Vighnesh Birodkar Date: Wed, 27 Aug 2014 21:09:25 +0530 Subject: [PATCH] handled the function with callbacks --- doc/examples/plot_rag_merge.py | 6 +-- skimage/graph/__init__.py | 3 +- skimage/graph/graph_merge.py | 76 ++++++++++++++++++++++++++------- skimage/graph/tests/test_rag.py | 2 +- 4 files changed, 67 insertions(+), 20 deletions(-) diff --git a/doc/examples/plot_rag_merge.py b/doc/examples/plot_rag_merge.py index 00236d95..d15ee954 100644 --- a/doc/examples/plot_rag_merge.py +++ b/doc/examples/plot_rag_merge.py @@ -5,8 +5,8 @@ RAG Merging This example constructs a Region Adjacency Graph (RAG) and progressively merges regions that are similar in color. Merging two adjacent regions produces -a new regions with all the pixels from the merged regions. Regions are merged -until no highly similar regions remain. +a new region with all the pixels from the merged regions. Regions are merged +until no highly similar region pairs remain. """ @@ -16,7 +16,7 @@ from skimage import graph, data, io, segmentation, color img = data.coffee() labels = segmentation.slic(img, compactness=30, n_segments=400) g = graph.rag_mean_color(img, labels) -labels2 = graph.merge_hierarchical(labels, g, 40) +labels2 = graph.merge_hierarchical_mean_color(labels, g, 40) g2 = graph.rag_mean_color(img, labels2) out = color.label2rgb(labels2, img, kind='avg') diff --git a/skimage/graph/__init__.py b/skimage/graph/__init__.py index f93708c9..7122bb3c 100644 --- a/skimage/graph/__init__.py +++ b/skimage/graph/__init__.py @@ -2,7 +2,7 @@ from .spath import shortest_path from .mcp import MCP, MCP_Geometric, MCP_Connect, MCP_Flexible, route_through_array from .graph_cut import cut_threshold, cut_normalized from .rag import rag_mean_color, RAG, draw_rag -from .graph_merge import merge_hierarchical +from .graph_merge import merge_hierarchical, merge_hierarchical_mean_color ncut = cut_normalized @@ -18,4 +18,5 @@ __all__ = ['shortest_path', 'ncut', 'draw_rag', 'merge_hierarchical', + 'merge_hierarchical_mean_color', 'RAG'] diff --git a/skimage/graph/graph_merge.py b/skimage/graph/graph_merge.py index cd103325..077d9723 100644 --- a/skimage/graph/graph_merge.py +++ b/skimage/graph/graph_merge.py @@ -2,7 +2,7 @@ import numpy as np import heapq -def _hmerge_mean_color(graph, src, dst, n): +def _weight_mean_color(graph, src, dst, n): """Callback to handle merging nodes by recomputing mean color. The method expects that the mean color of `dst` is already computed. @@ -26,6 +26,24 @@ def _hmerge_mean_color(graph, src, dst, n): return diff +def _pre_merge_mean_color(graph, src, dst): + """Callback called before merging two nodes of a mean color distance graph. + + This method computes the mean color of `dst`. + + Parameters + ---------- + graph : RAG + The graph under consideration. + src, dst : int + The vertices in `graph` to be merged. + """ + graph.node[dst]['total color'] += graph.node[src]['total color'] + graph.node[dst]['pixel count'] += graph.node[src]['pixel count'] + graph.node[dst]['mean color'] = (graph.node[dst]['total color'] / + graph.node[dst]['pixel count']) + + def _revalidate_node_edges(rag, node, heap_list): """Handles validation and invalidation of edges incident to a node. @@ -61,8 +79,10 @@ def _revalidate_node_edges(rag, node, heap_list): heapq.heappush(heap_list, heap_item) -def merge_hierarchical(labels, rag, thresh, in_place=True): - """Perform hierarchical merging of a RAG. +def merge_hierarchical_mean_color(labels, rag, thresh, in_place=True): + return merge_hierarchical(labels, rag, thresh, in_place, + _pre_merge_mean_color, _weight_mean_color) + """Perform hierarchical merging of a color distance RAG. Greedily merges the most similar pair of nodes until no edges lower than `thresh` remain. @@ -79,18 +99,48 @@ def merge_hierarchical(labels, rag, thresh, in_place=True): in_place : bool, optional If set, the RAG is modified in place. - Returns - ------- - out : ndarray - The new labeled array. - Examples -------- >>> from skimage import data, graph, segmentation >>> img = data.coffee() >>> labels = segmentation.slic(img) >>> rag = graph.rag_mean_color(img, labels) - >>> new_labels = graph.merge_hierarchical(labels, rag, 40) + >>> new_labels = graph.merge_hierarchical_mean_color(labels, rag, 40) + """ + + +def merge_hierarchical(labels, rag, thresh, in_place, pre_merge_func, + weight_func): + """Perform hierarchical merging of a RAG. + + Greedily merges the most similar pair of nodes until no edges lower than + `thresh` remain. + + Parameters + ---------- + labels : ndarray + The array of labels. + rag : RAG + The Region Adjacency Graph. + thresh : float + Regions connected by an edge with weight smaller than `thresh` are + merged. + in_place : bool, optional + If set, the RAG is modified in place. + pre_merge_func : callable + This function is called before merging two nodes. For the RAG `graph` + while merging `src` and `dst`, it is called as follows + ``pre_merge_func(graph, src, dst)``. + weight_func : callable + The function to compute the new weights of the nodes adjacent to the + merged node. This is directly supplied as the argument `weight_func` + to `merge_nodes`. + + Returns + ------- + out : ndarray + The new labeled array. + """ if not in_place: rag = rag.copy() @@ -110,16 +160,12 @@ def merge_hierarchical(labels, rag, thresh, in_place=True): # Ensure popped edge is valid, if not, the edge is discarded if valid: - rag.node[dst]['total color'] += rag.node[src]['total color'] - rag.node[dst]['pixel count'] += rag.node[src]['pixel count'] - rag.node[dst]['mean color'] = (rag.node[dst]['total color'] / - rag.node[dst]['pixel count']) - + _pre_merge_mean_color(rag, src, dst) # Invalidate all neigbors of `src` before its deleted for n in rag.neighbors(src): rag[src][n]['heap item'][3] = False - rag.merge_nodes(src, dst, _hmerge_mean_color) + rag.merge_nodes(src, dst, _weight_mean_color) _revalidate_node_edges(rag, dst, edge_heap) arr = np.arange(labels.max() + 1) diff --git a/skimage/graph/tests/test_rag.py b/skimage/graph/tests/test_rag.py index 4e287ceb..ba9ef994 100644 --- a/skimage/graph/tests/test_rag.py +++ b/skimage/graph/tests/test_rag.py @@ -124,7 +124,7 @@ def test_merge_hierarchical(): labels[50:, 50:] = 3 rag = graph.rag_mean_color(img, labels) - new_labels = graph.merge_hierarchical(labels, rag, 10) + new_labels = graph.merge_hierarchical_mean_color(labels, rag, 10) # Two labels assert new_labels.max() == 1