diff --git a/doc/examples/plot_rag_merge.py b/doc/examples/plot_rag_merge.py new file mode 100644 index 00000000..d454a096 --- /dev/null +++ b/doc/examples/plot_rag_merge.py @@ -0,0 +1,74 @@ +""" +=========== +RAG Merging +=========== + +This example constructs a Region Adjacency Graph (RAG) and progressively merges +regions that are similar in color. Merging two adjacent regions produces +a new region with all the pixels from the merged regions. Regions are merged +until no highly similar region pairs remain. + +""" + +from skimage import graph, data, io, segmentation, color +import numpy as np + + +def _weight_mean_color(graph, src, dst, n): + """Callback to handle merging nodes by recomputing mean color. + + The method expects that the mean color of `dst` is already computed. + + Parameters + ---------- + graph : RAG + The graph under consideration. + src, dst : int + The vertices in `graph` to be merged. + n : int + A neighbor of `src` or `dst` or both. + + Returns + ------- + weight : float + The absolute difference of the mean color between node `dst` and `n`. + """ + + diff = graph.node[dst]['mean color'] - graph.node[n]['mean color'] + diff = np.linalg.norm(diff) + return diff + + +def merge_mean_color(graph, src, dst): + """Callback called before merging two nodes of a mean color distance graph. + + This method computes the mean color of `dst`. + + Parameters + ---------- + graph : RAG + The graph under consideration. + src, dst : int + The vertices in `graph` to be merged. + """ + graph.node[dst]['total color'] += graph.node[src]['total color'] + graph.node[dst]['pixel count'] += graph.node[src]['pixel count'] + graph.node[dst]['mean color'] = (graph.node[dst]['total color'] / + graph.node[dst]['pixel count']) + + +img = data.coffee() +labels = segmentation.slic(img, compactness=30, n_segments=400) +g = graph.rag_mean_color(img, labels) + +labels2 = graph.merge_hierarchical(labels, g, thresh=40, rag_copy=False, + in_place_merge=True, + merge_func=merge_mean_color, + weight_func=_weight_mean_color) + +g2 = graph.rag_mean_color(img, labels2) + +out = color.label2rgb(labels2, img, kind='avg') +out = segmentation.mark_boundaries(out, labels2, (0, 0, 0)) +io.imshow(out) +io.show() diff --git a/skimage/graph/__init__.py b/skimage/graph/__init__.py index 6da4fcc8..f93708c9 100644 --- a/skimage/graph/__init__.py +++ b/skimage/graph/__init__.py @@ -2,6 +2,7 @@ from .spath import shortest_path from .mcp import MCP, MCP_Geometric, MCP_Connect, MCP_Flexible, route_through_array from .graph_cut import cut_threshold, cut_normalized from .rag import rag_mean_color, RAG, draw_rag +from .graph_merge import merge_hierarchical ncut = cut_normalized @@ -16,4 +17,5 @@ __all__ = ['shortest_path', 'cut_normalized', 'ncut', 'draw_rag', + 'merge_hierarchical', 'RAG'] diff --git a/skimage/graph/graph_merge.py b/skimage/graph/graph_merge.py new file mode 100644 index 00000000..308f7f37 --- /dev/null +++ b/skimage/graph/graph_merge.py @@ -0,0 +1,137 @@ +import numpy as np +import heapq + + +def _revalidate_node_edges(rag, node, heap_list): + """Handles validation and invalidation of edges incident to a node. + + This function invalidates all existing edges incident on `node` and inserts + new items in `heap_list` updated with the valid weights. + + rag : RAG + The Region Adjacency Graph. + node : int + The id of the node whose incident edges are to be validated/invalidated + . + heap_list : list + The list containing the existing heap of edges. + """ + # networkx updates data dictionary if edge exists + # this would mean we have to reposition these edges in + # heap if their weight is updated. + # instead we invalidate them + + for nbr in rag.neighbors(node): + data = rag[node][nbr] + try: + # invalidate edges incident on `dst`, they have new weights + data['heap item'][3] = False + _invalidate_edge(rag, node, nbr) + except KeyError: + # will handle the case where the edge did not exist in the existing + # graph + pass + + wt = data['weight'] + heap_item = [wt, node, nbr, True] + data['heap item'] = heap_item + heapq.heappush(heap_list, heap_item) + + +def _rename_node(graph, node_id, copy_id): + """ Rename `node_id` in `graph` to `copy_id`. """ + + graph._add_node_silent(copy_id) + graph.node[copy_id] = graph.node[node_id] + + for nbr in graph.neighbors(node_id): + wt = graph[node_id][nbr]['weight'] + graph.add_edge(nbr, copy_id, {'weight': wt}) + + graph.remove_node(node_id) + + +def _invalidate_edge(graph, n1, n2): + """ Invalidates the edge (n1, n2) in the heap. """ + graph[n1][n2]['heap item'][3] = False + + +def merge_hierarchical(labels, rag, thresh, rag_copy, in_place_merge, + merge_func, weight_func): + """Perform hierarchical merging of a RAG. + + Greedily merges the most similar pair of nodes until no edges lower than + `thresh` remain. + + Parameters + ---------- + labels : ndarray + The array of labels. + rag : RAG + The Region Adjacency Graph. + thresh : float + Regions connected by an edge with weight smaller than `thresh` are + merged. + rag_copy : bool + If set, the RAG copied before modifying. + in_place_merge : bool + If set, the nodes are merged in place. Otherwise, a new node is + created for each merge.. + merge_func : callable + This function is called before merging two nodes. For the RAG `graph` + while merging `src` and `dst`, it is called as follows + ``merge_func(graph, src, dst)``. + weight_func : callable + The function to compute the new weights of the nodes adjacent to the + merged node. This is directly supplied as the argument `weight_func` + to `merge_nodes`. + + Returns + ------- + out : ndarray + The new labeled array. + + """ + if rag_copy: + rag = rag.copy() + + edge_heap = [] + for n1, n2, data in rag.edges_iter(data=True): + # Push a valid edge in the heap + wt = data['weight'] + heap_item = [wt, n1, n2, True] + heapq.heappush(edge_heap, heap_item) + + # Reference to the heap item in the graph + data['heap item'] = heap_item + + while len(edge_heap) > 0 and edge_heap[0][0] < thresh: + _, n1, n2, valid = heapq.heappop(edge_heap) + + # Ensure popped edge is valid, if not, the edge is discarded + if valid: + # Invalidate all neigbors of `src` before its deleted + + for nbr in rag.neighbors(n1): + _invalidate_edge(rag, n1, nbr) + + for nbr in rag.neighbors(n2): + _invalidate_edge(rag, n2, nbr) + + if not in_place_merge: + next_id = rag.next_id() + _rename_node(rag, n2, next_id) + src, dst = n1, next_id + else: + src, dst = n1, n2 + + merge_func(rag, src, dst) + new_id = rag.merge_nodes(src, dst, weight_func) + _revalidate_node_edges(rag, new_id, edge_heap) + + label_map = np.arange(labels.max() + 1) + for ix, (n, d) in enumerate(rag.nodes_iter(data=True)): + for label in d['labels']: + label_map[label] = ix + + return label_map[labels] diff --git a/skimage/graph/rag.py b/skimage/graph/rag.py index 5f148c5e..120e9e91 100644 --- a/skimage/graph/rag.py +++ b/skimage/graph/rag.py @@ -164,6 +164,14 @@ class RAG(nx.Graph): """ return self.max_id + 1 + def _add_node_silent(self, n): + """Add node `n` without updating the maximum node id. + + This is a convenience method used internally. + + .. seealso:: :func:`networkx.Graph.add_node`.""" + super(RAG, self).add_node(n) + def _add_edge_filter(values, graph): """Create edge in `g` between the first element of `values` and the rest. diff --git a/skimage/graph/tests/test_rag.py b/skimage/graph/tests/test_rag.py index 35eb5c38..cfa49cdf 100644 --- a/skimage/graph/tests/test_rag.py +++ b/skimage/graph/tests/test_rag.py @@ -108,3 +108,54 @@ def test_rag_error(): labels[5:, :] = 1 testing.assert_raises(ValueError, graph.rag_mean_color, img, labels, 2, 'non existant mode') + + +def _weight_mean_color(graph, src, dst, n): + diff = graph.node[dst]['mean color'] - graph.node[n]['mean color'] + diff = np.linalg.norm(diff) + return diff + + +def _pre_merge_mean_color(graph, src, dst): + graph.node[dst]['total color'] += graph.node[src]['total color'] + graph.node[dst]['pixel count'] += graph.node[src]['pixel count'] + graph.node[dst]['mean color'] = (graph.node[dst]['total color'] / + graph.node[dst]['pixel count']) + + +def merge_hierarchical_mean_color(labels, rag, thresh, rag_copy=True, + in_place_merge=False): + return graph.merge_hierarchical(labels, rag, thresh, rag_copy, + in_place_merge, _pre_merge_mean_color, + _weight_mean_color) + + +@skipif(not is_installed('networkx')) +def test_rag_hierarchical(): + img = np.zeros((8, 8, 3), dtype='uint8') + labels = np.zeros((8, 8), dtype='uint8') + + img[:, :, :] = 31 + labels[:, :] = 1 + + img[0:4, 0:4, :] = 10, 10, 10 + labels[0:4, 0:4] = 2 + + img[4:, 0:4, :] = 20, 20, 20 + labels[4:, 0:4] = 3 + + g = graph.rag_mean_color(img, labels) + g2 = g.copy() + thresh = 20 # more than 11*sqrt(3) but less than + + result = merge_hierarchical_mean_color(labels, g, thresh) + assert(np.all(result[:, :4] == result[0, 0])) + assert(np.all(result[:, 4:] == result[-1, -1])) + + result = merge_hierarchical_mean_color(labels, g2, thresh, + in_place_merge=True) + assert(np.all(result[:, :4] == result[0, 0])) + assert(np.all(result[:, 4:] == result[-1, -1])) + + result = graph.cut_threshold(labels, g, thresh) + assert np.all(result == result[0, 0])