diff --git a/doc/examples/plot_rag.py b/doc/examples/plot_rag.py index b26a266d..f360af52 100644 --- a/doc/examples/plot_rag.py +++ b/doc/examples/plot_rag.py @@ -75,7 +75,7 @@ display(g, "Original Graph") g.merge_nodes(1, 3) display(g, "Merged with default (min)") -gc.merge_nodes(1, 3, weight_func=max_edge) -display(gc, "Merged with max") +gc.merge_nodes(1, 3, weight_func=max_edge, in_place=False) +display(gc, "Merged with max without in_place") plt.show() diff --git a/skimage/graph/rag.py b/skimage/graph/rag.py index 5995fb5d..8f605b1e 100644 --- a/skimage/graph/rag.py +++ b/skimage/graph/rag.py @@ -59,9 +59,9 @@ class RAG(nx.Graph): `networx.Graph `_ """ - def merge_nodes(self, src, dst, weight_func=min_weight, extra_arguments=[], - extra_keywords={}): - """Merge node `src` into `dst`. + def merge_nodes(self, src, dst, weight_func=min_weight, in_place=True, + extra_arguments=[], extra_keywords={}): + """Merge node `src` and `dst`. The new combined node is adjacent to all the neighbors of `src` and `dst`. `weight_func` is called to decide the weight of edges @@ -78,24 +78,44 @@ class RAG(nx.Graph): **extra_keywords)`. `src`, `dst` and `n` are IDs of vertices in the RAG object which is in turn a subclass of `networkx.Graph`. + in_place : bool + If set to `True`, the merged node has the id `dst`, else merged + node has a new id which is returned. extra_arguments : sequence, optional The sequence of extra positional arguments passed to `weight_func`. extra_keywords : dictionary, optional The dict of keyword arguments passed to the `weight_func`. + Returns + ------- + id : int + The id of the new node if `in_place` is `True`. """ src_nbrs = set(self.neighbors(src)) dst_nbrs = set(self.neighbors(dst)) neighbors = (src_nbrs & dst_nbrs) - set([src, dst]) + if not in_place: + new = self.number_of_nodes() + 1 + self.add_node(new) for neighbor in neighbors: w = weight_func(self, src, dst, neighbor, *extra_arguments, **extra_keywords) - self.add_edge(neighbor, dst, weight=w) + if in_place: + self.add_edge(neighbor, dst, weight=w) + else: + self.add_edge(neighbor, new, weight=w) - self.node[dst]['labels'] += self.node[src]['labels'] - self.remove_node(src) + if in_place: + self.node[dst]['labels'] += self.node[src]['labels'] + self.remove_node(src) + else: + self.node[new]['labels'] = (self.node[src]['labels'] + + self.node[dst]['labels']) + self.remove_node(src) + self.remove_node(dst) + return new def _add_edge_filter(values, graph): diff --git a/skimage/graph/tests/test_rag.py b/skimage/graph/tests/test_rag.py index 26f5a85c..04b82ff7 100644 --- a/skimage/graph/tests/test_rag.py +++ b/skimage/graph/tests/test_rag.py @@ -41,8 +41,8 @@ def test_rag_merge(): assert gc.edge[1][2]['weight'] == 20 assert gc.edge[2][3]['weight'] == 40 - g.merge_nodes(1, 4) - g.merge_nodes(2, 3) + g.merge_nodes(1, 4, in_place=True) + g.merge_nodes(2, 3, in_place=True) g.merge_nodes(3, 4) assert sorted(g.node[4]['labels']) == list(range(5)) assert g.edges() == []