Merge pull request #1100 from vighneshbirodkar/ha

Hierarchical Merging
This commit is contained in:
Juan Nunez-Iglesias
2015-01-30 16:05:58 +11:00
5 changed files with 272 additions and 0 deletions
+74
View File
@@ -0,0 +1,74 @@
"""
===========
RAG Merging
===========
This example constructs a Region Adjacency Graph (RAG) and progressively merges
regions that are similar in color. Merging two adjacent regions produces
a new region with all the pixels from the merged regions. Regions are merged
until no highly similar region pairs remain.
"""
from skimage import graph, data, io, segmentation, color
import numpy as np
def _weight_mean_color(graph, src, dst, n):
"""Callback to handle merging nodes by recomputing mean color.
The method expects that the mean color of `dst` is already computed.
Parameters
----------
graph : RAG
The graph under consideration.
src, dst : int
The vertices in `graph` to be merged.
n : int
A neighbor of `src` or `dst` or both.
Returns
-------
weight : float
The absolute difference of the mean color between node `dst` and `n`.
"""
diff = graph.node[dst]['mean color'] - graph.node[n]['mean color']
diff = np.linalg.norm(diff)
return diff
def merge_mean_color(graph, src, dst):
"""Callback called before merging two nodes of a mean color distance graph.
This method computes the mean color of `dst`.
Parameters
----------
graph : RAG
The graph under consideration.
src, dst : int
The vertices in `graph` to be merged.
"""
graph.node[dst]['total color'] += graph.node[src]['total color']
graph.node[dst]['pixel count'] += graph.node[src]['pixel count']
graph.node[dst]['mean color'] = (graph.node[dst]['total color'] /
graph.node[dst]['pixel count'])
img = data.coffee()
labels = segmentation.slic(img, compactness=30, n_segments=400)
g = graph.rag_mean_color(img, labels)
labels2 = graph.merge_hierarchical(labels, g, thresh=40, rag_copy=False,
in_place_merge=True,
merge_func=merge_mean_color,
weight_func=_weight_mean_color)
g2 = graph.rag_mean_color(img, labels2)
out = color.label2rgb(labels2, img, kind='avg')
out = segmentation.mark_boundaries(out, labels2, (0, 0, 0))
io.imshow(out)
io.show()
+2
View File
@@ -2,6 +2,7 @@ from .spath import shortest_path
from .mcp import MCP, MCP_Geometric, MCP_Connect, MCP_Flexible, route_through_array
from .graph_cut import cut_threshold, cut_normalized
from .rag import rag_mean_color, RAG, draw_rag
from .graph_merge import merge_hierarchical
ncut = cut_normalized
@@ -16,4 +17,5 @@ __all__ = ['shortest_path',
'cut_normalized',
'ncut',
'draw_rag',
'merge_hierarchical',
'RAG']
+137
View File
@@ -0,0 +1,137 @@
import numpy as np
import heapq
def _revalidate_node_edges(rag, node, heap_list):
"""Handles validation and invalidation of edges incident to a node.
This function invalidates all existing edges incident on `node` and inserts
new items in `heap_list` updated with the valid weights.
rag : RAG
The Region Adjacency Graph.
node : int
The id of the node whose incident edges are to be validated/invalidated
.
heap_list : list
The list containing the existing heap of edges.
"""
# networkx updates data dictionary if edge exists
# this would mean we have to reposition these edges in
# heap if their weight is updated.
# instead we invalidate them
for nbr in rag.neighbors(node):
data = rag[node][nbr]
try:
# invalidate edges incident on `dst`, they have new weights
data['heap item'][3] = False
_invalidate_edge(rag, node, nbr)
except KeyError:
# will handle the case where the edge did not exist in the existing
# graph
pass
wt = data['weight']
heap_item = [wt, node, nbr, True]
data['heap item'] = heap_item
heapq.heappush(heap_list, heap_item)
def _rename_node(graph, node_id, copy_id):
""" Rename `node_id` in `graph` to `copy_id`. """
graph._add_node_silent(copy_id)
graph.node[copy_id] = graph.node[node_id]
for nbr in graph.neighbors(node_id):
wt = graph[node_id][nbr]['weight']
graph.add_edge(nbr, copy_id, {'weight': wt})
graph.remove_node(node_id)
def _invalidate_edge(graph, n1, n2):
""" Invalidates the edge (n1, n2) in the heap. """
graph[n1][n2]['heap item'][3] = False
def merge_hierarchical(labels, rag, thresh, rag_copy, in_place_merge,
merge_func, weight_func):
"""Perform hierarchical merging of a RAG.
Greedily merges the most similar pair of nodes until no edges lower than
`thresh` remain.
Parameters
----------
labels : ndarray
The array of labels.
rag : RAG
The Region Adjacency Graph.
thresh : float
Regions connected by an edge with weight smaller than `thresh` are
merged.
rag_copy : bool
If set, the RAG copied before modifying.
in_place_merge : bool
If set, the nodes are merged in place. Otherwise, a new node is
created for each merge..
merge_func : callable
This function is called before merging two nodes. For the RAG `graph`
while merging `src` and `dst`, it is called as follows
``merge_func(graph, src, dst)``.
weight_func : callable
The function to compute the new weights of the nodes adjacent to the
merged node. This is directly supplied as the argument `weight_func`
to `merge_nodes`.
Returns
-------
out : ndarray
The new labeled array.
"""
if rag_copy:
rag = rag.copy()
edge_heap = []
for n1, n2, data in rag.edges_iter(data=True):
# Push a valid edge in the heap
wt = data['weight']
heap_item = [wt, n1, n2, True]
heapq.heappush(edge_heap, heap_item)
# Reference to the heap item in the graph
data['heap item'] = heap_item
while len(edge_heap) > 0 and edge_heap[0][0] < thresh:
_, n1, n2, valid = heapq.heappop(edge_heap)
# Ensure popped edge is valid, if not, the edge is discarded
if valid:
# Invalidate all neigbors of `src` before its deleted
for nbr in rag.neighbors(n1):
_invalidate_edge(rag, n1, nbr)
for nbr in rag.neighbors(n2):
_invalidate_edge(rag, n2, nbr)
if not in_place_merge:
next_id = rag.next_id()
_rename_node(rag, n2, next_id)
src, dst = n1, next_id
else:
src, dst = n1, n2
merge_func(rag, src, dst)
new_id = rag.merge_nodes(src, dst, weight_func)
_revalidate_node_edges(rag, new_id, edge_heap)
label_map = np.arange(labels.max() + 1)
for ix, (n, d) in enumerate(rag.nodes_iter(data=True)):
for label in d['labels']:
label_map[label] = ix
return label_map[labels]
+8
View File
@@ -164,6 +164,14 @@ class RAG(nx.Graph):
"""
return self.max_id + 1
def _add_node_silent(self, n):
"""Add node `n` without updating the maximum node id.
This is a convenience method used internally.
.. seealso:: :func:`networkx.Graph.add_node`."""
super(RAG, self).add_node(n)
def _add_edge_filter(values, graph):
"""Create edge in `g` between the first element of `values` and the rest.
+51
View File
@@ -108,3 +108,54 @@ def test_rag_error():
labels[5:, :] = 1
testing.assert_raises(ValueError, graph.rag_mean_color, img, labels,
2, 'non existant mode')
def _weight_mean_color(graph, src, dst, n):
diff = graph.node[dst]['mean color'] - graph.node[n]['mean color']
diff = np.linalg.norm(diff)
return diff
def _pre_merge_mean_color(graph, src, dst):
graph.node[dst]['total color'] += graph.node[src]['total color']
graph.node[dst]['pixel count'] += graph.node[src]['pixel count']
graph.node[dst]['mean color'] = (graph.node[dst]['total color'] /
graph.node[dst]['pixel count'])
def merge_hierarchical_mean_color(labels, rag, thresh, rag_copy=True,
in_place_merge=False):
return graph.merge_hierarchical(labels, rag, thresh, rag_copy,
in_place_merge, _pre_merge_mean_color,
_weight_mean_color)
@skipif(not is_installed('networkx'))
def test_rag_hierarchical():
img = np.zeros((8, 8, 3), dtype='uint8')
labels = np.zeros((8, 8), dtype='uint8')
img[:, :, :] = 31
labels[:, :] = 1
img[0:4, 0:4, :] = 10, 10, 10
labels[0:4, 0:4] = 2
img[4:, 0:4, :] = 20, 20, 20
labels[4:, 0:4] = 3
g = graph.rag_mean_color(img, labels)
g2 = g.copy()
thresh = 20 # more than 11*sqrt(3) but less than
result = merge_hierarchical_mean_color(labels, g, thresh)
assert(np.all(result[:, :4] == result[0, 0]))
assert(np.all(result[:, 4:] == result[-1, -1]))
result = merge_hierarchical_mean_color(labels, g2, thresh,
in_place_merge=True)
assert(np.all(result[:, :4] == result[0, 0]))
assert(np.all(result[:, 4:] == result[-1, -1]))
result = graph.cut_threshold(labels, g, thresh)
assert np.all(result == result[0, 0])