mirror of
https://github.com/wassname/scikit-image.git
synced 2026-06-27 21:08:24 +08:00
added non in-place merge
This commit is contained in:
@@ -75,7 +75,7 @@ display(g, "Original Graph")
|
||||
g.merge_nodes(1, 3)
|
||||
display(g, "Merged with default (min)")
|
||||
|
||||
gc.merge_nodes(1, 3, weight_func=max_edge)
|
||||
display(gc, "Merged with max")
|
||||
gc.merge_nodes(1, 3, weight_func=max_edge, in_place=False)
|
||||
display(gc, "Merged with max without in_place")
|
||||
|
||||
plt.show()
|
||||
|
||||
+26
-6
@@ -59,9 +59,9 @@ class RAG(nx.Graph):
|
||||
`networx.Graph <http://networkx.github.io/documentation/latest/reference/classes.graph.html>`_
|
||||
"""
|
||||
|
||||
def merge_nodes(self, src, dst, weight_func=min_weight, extra_arguments=[],
|
||||
extra_keywords={}):
|
||||
"""Merge node `src` into `dst`.
|
||||
def merge_nodes(self, src, dst, weight_func=min_weight, in_place=True,
|
||||
extra_arguments=[], extra_keywords={}):
|
||||
"""Merge node `src` and `dst`.
|
||||
|
||||
The new combined node is adjacent to all the neighbors of `src`
|
||||
and `dst`. `weight_func` is called to decide the weight of edges
|
||||
@@ -78,24 +78,44 @@ class RAG(nx.Graph):
|
||||
**extra_keywords)`. `src`, `dst` and `n` are IDs of vertices in the
|
||||
RAG object which is in turn a subclass of
|
||||
`networkx.Graph`.
|
||||
in_place : bool
|
||||
If set to `True`, the merged node has the id `dst`, else merged
|
||||
node has a new id which is returned.
|
||||
extra_arguments : sequence, optional
|
||||
The sequence of extra positional arguments passed to
|
||||
`weight_func`.
|
||||
extra_keywords : dictionary, optional
|
||||
The dict of keyword arguments passed to the `weight_func`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
id : int
|
||||
The id of the new node if `in_place` is `True`.
|
||||
"""
|
||||
src_nbrs = set(self.neighbors(src))
|
||||
dst_nbrs = set(self.neighbors(dst))
|
||||
neighbors = (src_nbrs & dst_nbrs) - set([src, dst])
|
||||
if not in_place:
|
||||
new = self.number_of_nodes() + 1
|
||||
self.add_node(new)
|
||||
|
||||
for neighbor in neighbors:
|
||||
w = weight_func(self, src, dst, neighbor, *extra_arguments,
|
||||
**extra_keywords)
|
||||
self.add_edge(neighbor, dst, weight=w)
|
||||
if in_place:
|
||||
self.add_edge(neighbor, dst, weight=w)
|
||||
else:
|
||||
self.add_edge(neighbor, new, weight=w)
|
||||
|
||||
self.node[dst]['labels'] += self.node[src]['labels']
|
||||
self.remove_node(src)
|
||||
if in_place:
|
||||
self.node[dst]['labels'] += self.node[src]['labels']
|
||||
self.remove_node(src)
|
||||
else:
|
||||
self.node[new]['labels'] = (self.node[src]['labels'] +
|
||||
self.node[dst]['labels'])
|
||||
self.remove_node(src)
|
||||
self.remove_node(dst)
|
||||
return new
|
||||
|
||||
|
||||
def _add_edge_filter(values, graph):
|
||||
|
||||
@@ -41,8 +41,8 @@ def test_rag_merge():
|
||||
assert gc.edge[1][2]['weight'] == 20
|
||||
assert gc.edge[2][3]['weight'] == 40
|
||||
|
||||
g.merge_nodes(1, 4)
|
||||
g.merge_nodes(2, 3)
|
||||
g.merge_nodes(1, 4, in_place=True)
|
||||
g.merge_nodes(2, 3, in_place=True)
|
||||
g.merge_nodes(3, 4)
|
||||
assert sorted(g.node[4]['labels']) == list(range(5))
|
||||
assert g.edges() == []
|
||||
|
||||
Reference in New Issue
Block a user