mirror of
https://github.com/wassname/scikit-image.git
synced 2026-06-29 17:37:20 +08:00
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
===========
|
||||
RAG Merging
|
||||
===========
|
||||
|
||||
This example constructs a Region Adjacency Graph (RAG) and progressively merges
|
||||
regions that are similar in color. Merging two adjacent regions produces
|
||||
a new region with all the pixels from the merged regions. Regions are merged
|
||||
until no highly similar region pairs remain.
|
||||
|
||||
"""
|
||||
|
||||
from skimage import graph, data, io, segmentation, color
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _weight_mean_color(graph, src, dst, n):
|
||||
"""Callback to handle merging nodes by recomputing mean color.
|
||||
|
||||
The method expects that the mean color of `dst` is already computed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
graph : RAG
|
||||
The graph under consideration.
|
||||
src, dst : int
|
||||
The vertices in `graph` to be merged.
|
||||
n : int
|
||||
A neighbor of `src` or `dst` or both.
|
||||
|
||||
Returns
|
||||
-------
|
||||
weight : float
|
||||
The absolute difference of the mean color between node `dst` and `n`.
|
||||
"""
|
||||
|
||||
diff = graph.node[dst]['mean color'] - graph.node[n]['mean color']
|
||||
diff = np.linalg.norm(diff)
|
||||
return diff
|
||||
|
||||
|
||||
def merge_mean_color(graph, src, dst):
|
||||
"""Callback called before merging two nodes of a mean color distance graph.
|
||||
|
||||
This method computes the mean color of `dst`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
graph : RAG
|
||||
The graph under consideration.
|
||||
src, dst : int
|
||||
The vertices in `graph` to be merged.
|
||||
"""
|
||||
graph.node[dst]['total color'] += graph.node[src]['total color']
|
||||
graph.node[dst]['pixel count'] += graph.node[src]['pixel count']
|
||||
graph.node[dst]['mean color'] = (graph.node[dst]['total color'] /
|
||||
graph.node[dst]['pixel count'])
|
||||
|
||||
|
||||
img = data.coffee()
|
||||
labels = segmentation.slic(img, compactness=30, n_segments=400)
|
||||
g = graph.rag_mean_color(img, labels)
|
||||
|
||||
labels2 = graph.merge_hierarchical(labels, g, thresh=40, rag_copy=False,
|
||||
in_place_merge=True,
|
||||
merge_func=merge_mean_color,
|
||||
weight_func=_weight_mean_color)
|
||||
|
||||
g2 = graph.rag_mean_color(img, labels2)
|
||||
|
||||
out = color.label2rgb(labels2, img, kind='avg')
|
||||
out = segmentation.mark_boundaries(out, labels2, (0, 0, 0))
|
||||
io.imshow(out)
|
||||
io.show()
|
||||
@@ -2,6 +2,7 @@ from .spath import shortest_path
|
||||
from .mcp import MCP, MCP_Geometric, MCP_Connect, MCP_Flexible, route_through_array
|
||||
from .graph_cut import cut_threshold, cut_normalized
|
||||
from .rag import rag_mean_color, RAG, draw_rag
|
||||
from .graph_merge import merge_hierarchical
|
||||
ncut = cut_normalized
|
||||
|
||||
|
||||
@@ -16,4 +17,5 @@ __all__ = ['shortest_path',
|
||||
'cut_normalized',
|
||||
'ncut',
|
||||
'draw_rag',
|
||||
'merge_hierarchical',
|
||||
'RAG']
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
import numpy as np
|
||||
import heapq
|
||||
|
||||
|
||||
def _revalidate_node_edges(rag, node, heap_list):
|
||||
"""Handles validation and invalidation of edges incident to a node.
|
||||
|
||||
This function invalidates all existing edges incident on `node` and inserts
|
||||
new items in `heap_list` updated with the valid weights.
|
||||
|
||||
rag : RAG
|
||||
The Region Adjacency Graph.
|
||||
node : int
|
||||
The id of the node whose incident edges are to be validated/invalidated
|
||||
.
|
||||
heap_list : list
|
||||
The list containing the existing heap of edges.
|
||||
"""
|
||||
# networkx updates data dictionary if edge exists
|
||||
# this would mean we have to reposition these edges in
|
||||
# heap if their weight is updated.
|
||||
# instead we invalidate them
|
||||
|
||||
for nbr in rag.neighbors(node):
|
||||
data = rag[node][nbr]
|
||||
try:
|
||||
# invalidate edges incident on `dst`, they have new weights
|
||||
data['heap item'][3] = False
|
||||
_invalidate_edge(rag, node, nbr)
|
||||
except KeyError:
|
||||
# will handle the case where the edge did not exist in the existing
|
||||
# graph
|
||||
pass
|
||||
|
||||
wt = data['weight']
|
||||
heap_item = [wt, node, nbr, True]
|
||||
data['heap item'] = heap_item
|
||||
heapq.heappush(heap_list, heap_item)
|
||||
|
||||
|
||||
def _rename_node(graph, node_id, copy_id):
|
||||
""" Rename `node_id` in `graph` to `copy_id`. """
|
||||
|
||||
graph._add_node_silent(copy_id)
|
||||
graph.node[copy_id] = graph.node[node_id]
|
||||
|
||||
for nbr in graph.neighbors(node_id):
|
||||
wt = graph[node_id][nbr]['weight']
|
||||
graph.add_edge(nbr, copy_id, {'weight': wt})
|
||||
|
||||
graph.remove_node(node_id)
|
||||
|
||||
|
||||
def _invalidate_edge(graph, n1, n2):
|
||||
""" Invalidates the edge (n1, n2) in the heap. """
|
||||
graph[n1][n2]['heap item'][3] = False
|
||||
|
||||
|
||||
def merge_hierarchical(labels, rag, thresh, rag_copy, in_place_merge,
|
||||
merge_func, weight_func):
|
||||
"""Perform hierarchical merging of a RAG.
|
||||
|
||||
Greedily merges the most similar pair of nodes until no edges lower than
|
||||
`thresh` remain.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
labels : ndarray
|
||||
The array of labels.
|
||||
rag : RAG
|
||||
The Region Adjacency Graph.
|
||||
thresh : float
|
||||
Regions connected by an edge with weight smaller than `thresh` are
|
||||
merged.
|
||||
rag_copy : bool
|
||||
If set, the RAG copied before modifying.
|
||||
in_place_merge : bool
|
||||
If set, the nodes are merged in place. Otherwise, a new node is
|
||||
created for each merge..
|
||||
merge_func : callable
|
||||
This function is called before merging two nodes. For the RAG `graph`
|
||||
while merging `src` and `dst`, it is called as follows
|
||||
``merge_func(graph, src, dst)``.
|
||||
weight_func : callable
|
||||
The function to compute the new weights of the nodes adjacent to the
|
||||
merged node. This is directly supplied as the argument `weight_func`
|
||||
to `merge_nodes`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : ndarray
|
||||
The new labeled array.
|
||||
|
||||
"""
|
||||
if rag_copy:
|
||||
rag = rag.copy()
|
||||
|
||||
edge_heap = []
|
||||
for n1, n2, data in rag.edges_iter(data=True):
|
||||
# Push a valid edge in the heap
|
||||
wt = data['weight']
|
||||
heap_item = [wt, n1, n2, True]
|
||||
heapq.heappush(edge_heap, heap_item)
|
||||
|
||||
# Reference to the heap item in the graph
|
||||
data['heap item'] = heap_item
|
||||
|
||||
while len(edge_heap) > 0 and edge_heap[0][0] < thresh:
|
||||
_, n1, n2, valid = heapq.heappop(edge_heap)
|
||||
|
||||
# Ensure popped edge is valid, if not, the edge is discarded
|
||||
if valid:
|
||||
# Invalidate all neigbors of `src` before its deleted
|
||||
|
||||
for nbr in rag.neighbors(n1):
|
||||
_invalidate_edge(rag, n1, nbr)
|
||||
|
||||
for nbr in rag.neighbors(n2):
|
||||
_invalidate_edge(rag, n2, nbr)
|
||||
|
||||
if not in_place_merge:
|
||||
next_id = rag.next_id()
|
||||
_rename_node(rag, n2, next_id)
|
||||
src, dst = n1, next_id
|
||||
else:
|
||||
src, dst = n1, n2
|
||||
|
||||
merge_func(rag, src, dst)
|
||||
new_id = rag.merge_nodes(src, dst, weight_func)
|
||||
_revalidate_node_edges(rag, new_id, edge_heap)
|
||||
|
||||
label_map = np.arange(labels.max() + 1)
|
||||
for ix, (n, d) in enumerate(rag.nodes_iter(data=True)):
|
||||
for label in d['labels']:
|
||||
label_map[label] = ix
|
||||
|
||||
return label_map[labels]
|
||||
@@ -164,6 +164,14 @@ class RAG(nx.Graph):
|
||||
"""
|
||||
return self.max_id + 1
|
||||
|
||||
def _add_node_silent(self, n):
|
||||
"""Add node `n` without updating the maximum node id.
|
||||
|
||||
This is a convenience method used internally.
|
||||
|
||||
.. seealso:: :func:`networkx.Graph.add_node`."""
|
||||
super(RAG, self).add_node(n)
|
||||
|
||||
|
||||
def _add_edge_filter(values, graph):
|
||||
"""Create edge in `g` between the first element of `values` and the rest.
|
||||
|
||||
@@ -108,3 +108,54 @@ def test_rag_error():
|
||||
labels[5:, :] = 1
|
||||
testing.assert_raises(ValueError, graph.rag_mean_color, img, labels,
|
||||
2, 'non existant mode')
|
||||
|
||||
|
||||
def _weight_mean_color(graph, src, dst, n):
|
||||
diff = graph.node[dst]['mean color'] - graph.node[n]['mean color']
|
||||
diff = np.linalg.norm(diff)
|
||||
return diff
|
||||
|
||||
|
||||
def _pre_merge_mean_color(graph, src, dst):
|
||||
graph.node[dst]['total color'] += graph.node[src]['total color']
|
||||
graph.node[dst]['pixel count'] += graph.node[src]['pixel count']
|
||||
graph.node[dst]['mean color'] = (graph.node[dst]['total color'] /
|
||||
graph.node[dst]['pixel count'])
|
||||
|
||||
|
||||
def merge_hierarchical_mean_color(labels, rag, thresh, rag_copy=True,
|
||||
in_place_merge=False):
|
||||
return graph.merge_hierarchical(labels, rag, thresh, rag_copy,
|
||||
in_place_merge, _pre_merge_mean_color,
|
||||
_weight_mean_color)
|
||||
|
||||
|
||||
@skipif(not is_installed('networkx'))
|
||||
def test_rag_hierarchical():
|
||||
img = np.zeros((8, 8, 3), dtype='uint8')
|
||||
labels = np.zeros((8, 8), dtype='uint8')
|
||||
|
||||
img[:, :, :] = 31
|
||||
labels[:, :] = 1
|
||||
|
||||
img[0:4, 0:4, :] = 10, 10, 10
|
||||
labels[0:4, 0:4] = 2
|
||||
|
||||
img[4:, 0:4, :] = 20, 20, 20
|
||||
labels[4:, 0:4] = 3
|
||||
|
||||
g = graph.rag_mean_color(img, labels)
|
||||
g2 = g.copy()
|
||||
thresh = 20 # more than 11*sqrt(3) but less than
|
||||
|
||||
result = merge_hierarchical_mean_color(labels, g, thresh)
|
||||
assert(np.all(result[:, :4] == result[0, 0]))
|
||||
assert(np.all(result[:, 4:] == result[-1, -1]))
|
||||
|
||||
result = merge_hierarchical_mean_color(labels, g2, thresh,
|
||||
in_place_merge=True)
|
||||
assert(np.all(result[:, :4] == result[0, 0]))
|
||||
assert(np.all(result[:, 4:] == result[-1, -1]))
|
||||
|
||||
result = graph.cut_threshold(labels, g, thresh)
|
||||
assert np.all(result == result[0, 0])
|
||||
|
||||
Reference in New Issue
Block a user