mirror of
https://github.com/wassname/scikit-image.git
synced 2026-06-27 19:48:43 +08:00
handled the function with callbacks
This commit is contained in:
@@ -5,8 +5,8 @@ RAG Merging
|
||||
|
||||
This example constructs a Region Adjacency Graph (RAG) and progressively merges
|
||||
regions that are similar in color. Merging two adjacent regions produces
|
||||
a new regions with all the pixels from the merged regions. Regions are merged
|
||||
until no highly similar regions remain.
|
||||
a new region with all the pixels from the merged regions. Regions are merged
|
||||
until no highly similar region pairs remain.
|
||||
|
||||
"""
|
||||
|
||||
@@ -16,7 +16,7 @@ from skimage import graph, data, io, segmentation, color
|
||||
img = data.coffee()
|
||||
labels = segmentation.slic(img, compactness=30, n_segments=400)
|
||||
g = graph.rag_mean_color(img, labels)
|
||||
labels2 = graph.merge_hierarchical(labels, g, 40)
|
||||
labels2 = graph.merge_hierarchical_mean_color(labels, g, 40)
|
||||
g2 = graph.rag_mean_color(img, labels2)
|
||||
|
||||
out = color.label2rgb(labels2, img, kind='avg')
|
||||
|
||||
@@ -2,7 +2,7 @@ from .spath import shortest_path
|
||||
from .mcp import MCP, MCP_Geometric, MCP_Connect, MCP_Flexible, route_through_array
|
||||
from .graph_cut import cut_threshold, cut_normalized
|
||||
from .rag import rag_mean_color, RAG, draw_rag
|
||||
from .graph_merge import merge_hierarchical
|
||||
from .graph_merge import merge_hierarchical, merge_hierarchical_mean_color
|
||||
ncut = cut_normalized
|
||||
|
||||
|
||||
@@ -18,4 +18,5 @@ __all__ = ['shortest_path',
|
||||
'ncut',
|
||||
'draw_rag',
|
||||
'merge_hierarchical',
|
||||
'merge_hierarchical_mean_color',
|
||||
'RAG']
|
||||
|
||||
@@ -2,7 +2,7 @@ import numpy as np
|
||||
import heapq
|
||||
|
||||
|
||||
def _hmerge_mean_color(graph, src, dst, n):
|
||||
def _weight_mean_color(graph, src, dst, n):
|
||||
"""Callback to handle merging nodes by recomputing mean color.
|
||||
|
||||
The method expects that the mean color of `dst` is already computed.
|
||||
@@ -26,6 +26,24 @@ def _hmerge_mean_color(graph, src, dst, n):
|
||||
return diff
|
||||
|
||||
|
||||
def _pre_merge_mean_color(graph, src, dst):
|
||||
"""Callback called before merging two nodes of a mean color distance graph.
|
||||
|
||||
This method computes the mean color of `dst`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
graph : RAG
|
||||
The graph under consideration.
|
||||
src, dst : int
|
||||
The vertices in `graph` to be merged.
|
||||
"""
|
||||
graph.node[dst]['total color'] += graph.node[src]['total color']
|
||||
graph.node[dst]['pixel count'] += graph.node[src]['pixel count']
|
||||
graph.node[dst]['mean color'] = (graph.node[dst]['total color'] /
|
||||
graph.node[dst]['pixel count'])
|
||||
|
||||
|
||||
def _revalidate_node_edges(rag, node, heap_list):
|
||||
"""Handles validation and invalidation of edges incident to a node.
|
||||
|
||||
@@ -61,8 +79,10 @@ def _revalidate_node_edges(rag, node, heap_list):
|
||||
heapq.heappush(heap_list, heap_item)
|
||||
|
||||
|
||||
def merge_hierarchical(labels, rag, thresh, in_place=True):
|
||||
"""Perform hierarchical merging of a RAG.
|
||||
def merge_hierarchical_mean_color(labels, rag, thresh, in_place=True):
|
||||
return merge_hierarchical(labels, rag, thresh, in_place,
|
||||
_pre_merge_mean_color, _weight_mean_color)
|
||||
"""Perform hierarchical merging of a color distance RAG.
|
||||
|
||||
Greedily merges the most similar pair of nodes until no edges lower than
|
||||
`thresh` remain.
|
||||
@@ -79,18 +99,48 @@ def merge_hierarchical(labels, rag, thresh, in_place=True):
|
||||
in_place : bool, optional
|
||||
If set, the RAG is modified in place.
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : ndarray
|
||||
The new labeled array.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from skimage import data, graph, segmentation
|
||||
>>> img = data.coffee()
|
||||
>>> labels = segmentation.slic(img)
|
||||
>>> rag = graph.rag_mean_color(img, labels)
|
||||
>>> new_labels = graph.merge_hierarchical(labels, rag, 40)
|
||||
>>> new_labels = graph.merge_hierarchical_mean_color(labels, rag, 40)
|
||||
"""
|
||||
|
||||
|
||||
def merge_hierarchical(labels, rag, thresh, in_place, pre_merge_func,
|
||||
weight_func):
|
||||
"""Perform hierarchical merging of a RAG.
|
||||
|
||||
Greedily merges the most similar pair of nodes until no edges lower than
|
||||
`thresh` remain.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
labels : ndarray
|
||||
The array of labels.
|
||||
rag : RAG
|
||||
The Region Adjacency Graph.
|
||||
thresh : float
|
||||
Regions connected by an edge with weight smaller than `thresh` are
|
||||
merged.
|
||||
in_place : bool, optional
|
||||
If set, the RAG is modified in place.
|
||||
pre_merge_func : callable
|
||||
This function is called before merging two nodes. For the RAG `graph`
|
||||
while merging `src` and `dst`, it is called as follows
|
||||
``pre_merge_func(graph, src, dst)``.
|
||||
weight_func : callable
|
||||
The function to compute the new weights of the nodes adjacent to the
|
||||
merged node. This is directly supplied as the argument `weight_func`
|
||||
to `merge_nodes`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : ndarray
|
||||
The new labeled array.
|
||||
|
||||
"""
|
||||
if not in_place:
|
||||
rag = rag.copy()
|
||||
@@ -110,16 +160,12 @@ def merge_hierarchical(labels, rag, thresh, in_place=True):
|
||||
|
||||
# Ensure popped edge is valid, if not, the edge is discarded
|
||||
if valid:
|
||||
rag.node[dst]['total color'] += rag.node[src]['total color']
|
||||
rag.node[dst]['pixel count'] += rag.node[src]['pixel count']
|
||||
rag.node[dst]['mean color'] = (rag.node[dst]['total color'] /
|
||||
rag.node[dst]['pixel count'])
|
||||
|
||||
_pre_merge_mean_color(rag, src, dst)
|
||||
# Invalidate all neigbors of `src` before its deleted
|
||||
for n in rag.neighbors(src):
|
||||
rag[src][n]['heap item'][3] = False
|
||||
|
||||
rag.merge_nodes(src, dst, _hmerge_mean_color)
|
||||
rag.merge_nodes(src, dst, _weight_mean_color)
|
||||
_revalidate_node_edges(rag, dst, edge_heap)
|
||||
|
||||
arr = np.arange(labels.max() + 1)
|
||||
|
||||
@@ -124,7 +124,7 @@ def test_merge_hierarchical():
|
||||
labels[50:, 50:] = 3
|
||||
|
||||
rag = graph.rag_mean_color(img, labels)
|
||||
new_labels = graph.merge_hierarchical(labels, rag, 10)
|
||||
new_labels = graph.merge_hierarchical_mean_color(labels, rag, 10)
|
||||
|
||||
# Two labels
|
||||
assert new_labels.max() == 1
|
||||
|
||||
Reference in New Issue
Block a user