complete heap implementation

This commit is contained in:
Vighnesh Birodkar
2014-08-16 16:03:50 +05:30
parent bb11caa1e1
commit d0e2552863
+45 -35
View File
@@ -2,10 +2,10 @@ import numpy as np
import heapq
def _hmerge(rag, x, y, n):
def _hmerge_mean_color(graph, src, dst, n):
"""Callback to handle merging nodes by recomputing mean color.
The method expects that the mean color of `y` is already computed.
The method expects that the mean color of `dst` is already computed.
Parameters
----------
@@ -19,20 +19,14 @@ def _hmerge(rag, x, y, n):
Returns
-------
weight : float
The absolute difference of the mean color between node `y` and `n`.
The absolute difference of the mean color between node `dst` and `n`.
"""
diff = rag.node[y]['mean color'] - rag.node[n]['mean color']
diff = graph.node[dst]['mean color'] - graph.node[n]['mean color']
diff = np.linalg.norm(diff)
#rag.add_edge(y,n)
#rag[y][n]['valid'] = True
#heapq.heappush(heap,(diff, y , n, data))
return diff
def merge_hierarchical(labels, rag, thresh):
def merge_hierarchical(labels, rag, thresh, in_place=True):
"""Perform hierarchical merging of a RAG.
Given an image's labels and its RAG, the method merges the similar nodes
@@ -48,6 +42,8 @@ def merge_hierarchical(labels, rag, thresh):
The threshold. Regions connected by an edges with smaller wegiht than
`thresh` are merged. A high value of `thresh` would mean that a lot of
regions are merged, and the output will contain fewer regions.
in_place : bool, optional
If set, the RAG is modified in place.
Returns
-------
@@ -63,22 +59,23 @@ def merge_hierarchical(labels, rag, thresh):
>>> new_labels = graph.merge_hierarchical(labels, rag, 40)
"""
min_wt = 0
if not in_place:
rag = rag.copy()
edge_heap = []
for x,y,data in rag.edges_iter(data=True):
for x, y, data in rag.edges_iter(data=True):
if x != y:
# Validate all edges and push them in heap
data['valid'] = True
wt = data['weight']
heapq.heappush(edge_heap, (wt, x, y, data))
while min_wt < thresh:
#valid_edges = ((x, y, d) for x, y, d in rag.edges(data=True) if x != y)
#x, y, d = min(valid_edges, key=lambda x: x[2]['weight'])
#min_wt = d['weight']
min_wt,x,y,data = heapq.heappop(edge_heap)
print min_wt
min_wt, x, y, data = heapq.heappop(edge_heap)
# Ensure popped edge is valid, if not, the edge is discarded
if min_wt < thresh and data['valid']:
total_color = (rag.node[y]['total color'] +
rag.node[x]['total color'])
@@ -86,25 +83,38 @@ def merge_hierarchical(labels, rag, thresh):
rag.node[y]['total color'] = total_color
rag.node[y]['pixel count'] = n_pixels
rag.node[y]['mean color'] = total_color / n_pixels
#print x,y
for n in rag.neighbors(x):
rag[x][n]['valid'] = False
for n in rag.neighbors(y):
rag[y][n]['valid'] = False
rag.merge_nodes(x, y, _hmerge)
# This will invalidate all the below edges in the heap
for n in rag.neighbors(x):
rag[x][n]['valid'] = False
for n in rag.neighbors(y):
if n!= y:
rag[y][n]['valid'] = True
rag[y][n]['valid'] = False
rag.merge_nodes(x, y, _hmerge_mean_color)
for n in rag.neighbors(y):
if n != y:
# networkx updates data dictionary if edge exists
# this would mean we have to reposition these edges in
# heap if their weight is updated.
# instead we invalidate them
# invalidates the edge in the heap, if it all it exists
data = rag[y][n]
data['valid'] = False
# allocate a new dictionary for the edge
data_copy = data.copy()
rag[y][n] = data_copy
rag[n][y] = data_copy
# validate this edge
rag.add_edge(y, n, valid=True)
# push the new validated edge in the heap, this will be
# moved to its proper position
wt = rag[y][n]['weight']
heapq.heappush(edge_heap, (wt, y, n, rag[y][n]))
arr = np.arange(labels.max() + 1)
for ix, (n, d) in enumerate(rag.nodes_iter(data=True)):