handled the function with callbacks

This commit is contained in:
Vighnesh Birodkar
2014-08-27 21:09:25 +05:30
parent e41b34314a
commit 4c4d7108e1
4 changed files with 67 additions and 20 deletions
+3 -3
View File
@@ -5,8 +5,8 @@ RAG Merging
This example constructs a Region Adjacency Graph (RAG) and progressively merges
regions that are similar in color. Merging two adjacent regions produces
a new regions with all the pixels from the merged regions. Regions are merged
until no highly similar regions remain.
a new region with all the pixels from the merged regions. Regions are merged
until no highly similar region pairs remain.
"""
@@ -16,7 +16,7 @@ from skimage import graph, data, io, segmentation, color
img = data.coffee()
labels = segmentation.slic(img, compactness=30, n_segments=400)
g = graph.rag_mean_color(img, labels)
labels2 = graph.merge_hierarchical(labels, g, 40)
labels2 = graph.merge_hierarchical_mean_color(labels, g, 40)
g2 = graph.rag_mean_color(img, labels2)
out = color.label2rgb(labels2, img, kind='avg')
+2 -1
View File
@@ -2,7 +2,7 @@ from .spath import shortest_path
from .mcp import MCP, MCP_Geometric, MCP_Connect, MCP_Flexible, route_through_array
from .graph_cut import cut_threshold, cut_normalized
from .rag import rag_mean_color, RAG, draw_rag
from .graph_merge import merge_hierarchical
from .graph_merge import merge_hierarchical, merge_hierarchical_mean_color
ncut = cut_normalized
@@ -18,4 +18,5 @@ __all__ = ['shortest_path',
'ncut',
'draw_rag',
'merge_hierarchical',
'merge_hierarchical_mean_color',
'RAG']
+61 -15
View File
@@ -2,7 +2,7 @@ import numpy as np
import heapq
def _hmerge_mean_color(graph, src, dst, n):
def _weight_mean_color(graph, src, dst, n):
"""Callback to handle merging nodes by recomputing mean color.
The method expects that the mean color of `dst` is already computed.
@@ -26,6 +26,24 @@ def _hmerge_mean_color(graph, src, dst, n):
return diff
def _pre_merge_mean_color(graph, src, dst):
"""Callback called before merging two nodes of a mean color distance graph.
This method computes the mean color of `dst`.
Parameters
----------
graph : RAG
The graph under consideration.
src, dst : int
The vertices in `graph` to be merged.
"""
graph.node[dst]['total color'] += graph.node[src]['total color']
graph.node[dst]['pixel count'] += graph.node[src]['pixel count']
graph.node[dst]['mean color'] = (graph.node[dst]['total color'] /
graph.node[dst]['pixel count'])
def _revalidate_node_edges(rag, node, heap_list):
"""Handles validation and invalidation of edges incident to a node.
@@ -61,8 +79,10 @@ def _revalidate_node_edges(rag, node, heap_list):
heapq.heappush(heap_list, heap_item)
def merge_hierarchical(labels, rag, thresh, in_place=True):
"""Perform hierarchical merging of a RAG.
def merge_hierarchical_mean_color(labels, rag, thresh, in_place=True):
return merge_hierarchical(labels, rag, thresh, in_place,
_pre_merge_mean_color, _weight_mean_color)
"""Perform hierarchical merging of a color distance RAG.
Greedily merges the most similar pair of nodes until no edges lower than
`thresh` remain.
@@ -79,18 +99,48 @@ def merge_hierarchical(labels, rag, thresh, in_place=True):
in_place : bool, optional
If set, the RAG is modified in place.
Returns
-------
out : ndarray
The new labeled array.
Examples
--------
>>> from skimage import data, graph, segmentation
>>> img = data.coffee()
>>> labels = segmentation.slic(img)
>>> rag = graph.rag_mean_color(img, labels)
>>> new_labels = graph.merge_hierarchical(labels, rag, 40)
>>> new_labels = graph.merge_hierarchical_mean_color(labels, rag, 40)
"""
def merge_hierarchical(labels, rag, thresh, in_place, pre_merge_func,
weight_func):
"""Perform hierarchical merging of a RAG.
Greedily merges the most similar pair of nodes until no edges lower than
`thresh` remain.
Parameters
----------
labels : ndarray
The array of labels.
rag : RAG
The Region Adjacency Graph.
thresh : float
Regions connected by an edge with weight smaller than `thresh` are
merged.
in_place : bool, optional
If set, the RAG is modified in place.
pre_merge_func : callable
This function is called before merging two nodes. For the RAG `graph`
while merging `src` and `dst`, it is called as follows
``pre_merge_func(graph, src, dst)``.
weight_func : callable
The function to compute the new weights of the nodes adjacent to the
merged node. This is directly supplied as the argument `weight_func`
to `merge_nodes`.
Returns
-------
out : ndarray
The new labeled array.
"""
if not in_place:
rag = rag.copy()
@@ -110,16 +160,12 @@ def merge_hierarchical(labels, rag, thresh, in_place=True):
# Ensure popped edge is valid, if not, the edge is discarded
if valid:
rag.node[dst]['total color'] += rag.node[src]['total color']
rag.node[dst]['pixel count'] += rag.node[src]['pixel count']
rag.node[dst]['mean color'] = (rag.node[dst]['total color'] /
rag.node[dst]['pixel count'])
_pre_merge_mean_color(rag, src, dst)
# Invalidate all neigbors of `src` before its deleted
for n in rag.neighbors(src):
rag[src][n]['heap item'][3] = False
rag.merge_nodes(src, dst, _hmerge_mean_color)
rag.merge_nodes(src, dst, _weight_mean_color)
_revalidate_node_edges(rag, dst, edge_heap)
arr = np.arange(labels.max() + 1)
+1 -1
View File
@@ -124,7 +124,7 @@ def test_merge_hierarchical():
labels[50:, 50:] = 3
rag = graph.rag_mean_color(img, labels)
new_labels = graph.merge_hierarchical(labels, rag, 10)
new_labels = graph.merge_hierarchical_mean_color(labels, rag, 10)
# Two labels
assert new_labels.max() == 1