diff --git a/doc/examples/plot_rag.py b/doc/examples/plot_rag.py index b26a266d..f360af52 100644 --- a/doc/examples/plot_rag.py +++ b/doc/examples/plot_rag.py @@ -75,7 +75,7 @@ display(g, "Original Graph") g.merge_nodes(1, 3) display(g, "Merged with default (min)") -gc.merge_nodes(1, 3, weight_func=max_edge) -display(gc, "Merged with max") +gc.merge_nodes(1, 3, weight_func=max_edge, in_place=False) +display(gc, "Merged with max without in_place") plt.show() diff --git a/skimage/graph/rag.py b/skimage/graph/rag.py index 5995fb5d..81b3c0d6 100644 --- a/skimage/graph/rag.py +++ b/skimage/graph/rag.py @@ -59,9 +59,18 @@ class RAG(nx.Graph): `networx.Graph `_ """ - def merge_nodes(self, src, dst, weight_func=min_weight, extra_arguments=[], - extra_keywords={}): - """Merge node `src` into `dst`. + def __init__(self, data=None, **attr): + + super(RAG, self).__init__(data, **attr) + try: + self.max_id = max(self.nodes_iter()) + except ValueError: + # Empty sequence + self.max_id = 0 + + def merge_nodes(self, src, dst, weight_func=min_weight, in_place=True, + extra_arguments=[], extra_keywords={}): + """Merge node `src` and `dst`. The new combined node is adjacent to all the neighbors of `src` and `dst`. `weight_func` is called to decide the weight of edges @@ -78,25 +87,83 @@ class RAG(nx.Graph): **extra_keywords)`. `src`, `dst` and `n` are IDs of vertices in the RAG object which is in turn a subclass of `networkx.Graph`. + in_place : bool, optional + If set to `True`, the merged node has the id `dst`, else merged + node has a new id which is returned. extra_arguments : sequence, optional The sequence of extra positional arguments passed to `weight_func`. extra_keywords : dictionary, optional The dict of keyword arguments passed to the `weight_func`. + Returns + ------- + id : int + The id of the new node. + + Notes + ----- + If `in_place` is `False` the resulting node has a new id, rather than + `dst`. """ src_nbrs = set(self.neighbors(src)) dst_nbrs = set(self.neighbors(dst)) - neighbors = (src_nbrs & dst_nbrs) - set([src, dst]) + neighbors = (src_nbrs | dst_nbrs) - set([src, dst]) + + if in_place: + new = dst + else: + new = self.next_id() + self.add_node(new) for neighbor in neighbors: - w = weight_func(self, src, dst, neighbor, *extra_arguments, + w = weight_func(self, src, new, neighbor, *extra_arguments, **extra_keywords) - self.add_edge(neighbor, dst, weight=w) + self.add_edge(neighbor, new, weight=w) - self.node[dst]['labels'] += self.node[src]['labels'] + self.node[new]['labels'] = (self.node[src]['labels'] + + self.node[dst]['labels']) self.remove_node(src) + if not in_place: + self.remove_node(dst) + + return new + + def add_node(self, n, attr_dict=None, **attr): + """Add node `n` while updating the maximum node id. + + .. seealso:: :func:`networkx.Graph.add_node`.""" + super(RAG, self).add_node(n, attr_dict, **attr) + self.max_id = max(n, self.max_id) + + def add_edge(self, u, v, attr_dict=None, **attr): + """Add an edge between `u` and `v` while updating max node id. + + .. seealso:: :func:`networkx.Graph.add_edge`.""" + super(RAG, self).add_edge(u, v, attr_dict, **attr) + self.max_id = max(u, v, self.max_id) + + def copy(self): + """Copy the graph with its max node id. + + .. seealso:: :func:`networkx.Graph.copy`.""" + g = super(RAG, self).copy() + g.max_id = self.max_id + return g + + def next_id(self): + """Returns the `id` for the new node to be inserted. + + The current implementation returns one more than the maximum `id`. + + Returns + ------- + id : int + The `id` of the new node to be inserted. + """ + return self.max_id + 1 + def _add_edge_filter(values, graph): """Create edge in `g` between the first element of `values` and the rest. diff --git a/skimage/graph/tests/test_rag.py b/skimage/graph/tests/test_rag.py index 26f5a85c..35eb5c38 100644 --- a/skimage/graph/tests/test_rag.py +++ b/skimage/graph/tests/test_rag.py @@ -43,8 +43,8 @@ def test_rag_merge(): g.merge_nodes(1, 4) g.merge_nodes(2, 3) - g.merge_nodes(3, 4) - assert sorted(g.node[4]['labels']) == list(range(5)) + n = g.merge_nodes(3, 4, in_place=False) + assert sorted(g.node[n]['labels']) == list(range(5)) assert g.edges() == []