diff --git a/skimage/graph/graph_merge.py b/skimage/graph/graph_merge.py index d7499066..68be4392 100644 --- a/skimage/graph/graph_merge.py +++ b/skimage/graph/graph_merge.py @@ -2,10 +2,10 @@ import numpy as np import heapq -def _hmerge(rag, x, y, n): +def _hmerge_mean_color(graph, src, dst, n): """Callback to handle merging nodes by recomputing mean color. - The method expects that the mean color of `y` is already computed. + The method expects that the mean color of `dst` is already computed. Parameters ---------- @@ -19,20 +19,14 @@ def _hmerge(rag, x, y, n): Returns ------- weight : float - The absolute difference of the mean color between node `y` and `n`. + The absolute difference of the mean color between node `dst` and `n`. """ - diff = rag.node[y]['mean color'] - rag.node[n]['mean color'] - + diff = graph.node[dst]['mean color'] - graph.node[n]['mean color'] diff = np.linalg.norm(diff) - - #rag.add_edge(y,n) - #rag[y][n]['valid'] = True - #heapq.heappush(heap,(diff, y , n, data)) - return diff -def merge_hierarchical(labels, rag, thresh): +def merge_hierarchical(labels, rag, thresh, in_place=True): """Perform hierarchical merging of a RAG. Given an image's labels and its RAG, the method merges the similar nodes @@ -48,6 +42,8 @@ def merge_hierarchical(labels, rag, thresh): The threshold. Regions connected by an edges with smaller wegiht than `thresh` are merged. A high value of `thresh` would mean that a lot of regions are merged, and the output will contain fewer regions. + in_place : bool, optional + If set, the RAG is modified in place. Returns ------- @@ -63,22 +59,23 @@ def merge_hierarchical(labels, rag, thresh): >>> new_labels = graph.merge_hierarchical(labels, rag, 40) """ min_wt = 0 - + + if not in_place: + rag = rag.copy() + edge_heap = [] - for x,y,data in rag.edges_iter(data=True): + for x, y, data in rag.edges_iter(data=True): if x != y: + # Validate all edges and push them in heap data['valid'] = True wt = data['weight'] heapq.heappush(edge_heap, (wt, x, y, data)) - - + while min_wt < thresh: - #valid_edges = ((x, y, d) for x, y, d in rag.edges(data=True) if x != y) - #x, y, d = min(valid_edges, key=lambda x: x[2]['weight']) - #min_wt = d['weight'] - min_wt,x,y,data = heapq.heappop(edge_heap) - print min_wt + min_wt, x, y, data = heapq.heappop(edge_heap) + + # Ensure popped edge is valid, if not, the edge is discarded if min_wt < thresh and data['valid']: total_color = (rag.node[y]['total color'] + rag.node[x]['total color']) @@ -86,25 +83,38 @@ def merge_hierarchical(labels, rag, thresh): rag.node[y]['total color'] = total_color rag.node[y]['pixel count'] = n_pixels rag.node[y]['mean color'] = total_color / n_pixels - - - #print x,y - for n in rag.neighbors(x): - rag[x][n]['valid'] = False - - for n in rag.neighbors(y): - rag[y][n]['valid'] = False - rag.merge_nodes(x, y, _hmerge) - - + # This will invalidate all the below edges in the heap + for n in rag.neighbors(x): + rag[x][n]['valid'] = False + for n in rag.neighbors(y): - if n!= y: - rag[y][n]['valid'] = True + rag[y][n]['valid'] = False + + rag.merge_nodes(x, y, _hmerge_mean_color) + for n in rag.neighbors(y): + if n != y: + # networkx updates data dictionary if edge exists + # this would mean we have to reposition these edges in + # heap if their weight is updated. + # instead we invalidate them + + # invalidates the edge in the heap, if it all it exists + data = rag[y][n] + data['valid'] = False + + # allocate a new dictionary for the edge + data_copy = data.copy() + rag[y][n] = data_copy + rag[n][y] = data_copy + + # validate this edge + rag.add_edge(y, n, valid=True) + + # push the new validated edge in the heap, this will be + # moved to its proper position wt = rag[y][n]['weight'] heapq.heappush(edge_heap, (wt, y, n, rag[y][n])) - - arr = np.arange(labels.max() + 1) for ix, (n, d) in enumerate(rag.nodes_iter(data=True)):