mirror of
https://github.com/wassname/scikit-image.git
synced 2026-06-27 21:08:24 +08:00
Added testcase
This commit is contained in:
@@ -1,9 +1,13 @@
|
||||
from .spath import shortest_path
|
||||
from .mcp import MCP, MCP_Geometric, MCP_Connect, MCP_Flexible, route_through_array
|
||||
from .rag import rag_meancolor
|
||||
from .graph_cut import threshold_cut
|
||||
|
||||
__all__ = ['shortest_path',
|
||||
'MCP',
|
||||
'MCP_Geometric',
|
||||
'MCP_Connect',
|
||||
'MCP_Flexible',
|
||||
'route_through_array']
|
||||
'route_through_array',
|
||||
'rag_meancolor',
|
||||
'threshold_cut']
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
import netwrokx as nx
|
||||
|
||||
class Graph(nx.Graph):
|
||||
|
||||
def merge_nodes(i,j):
|
||||
if not self.has_edge(i, j):
|
||||
raise ValueError('Cant merge non adjacent nodes')
|
||||
|
||||
# print "before ",self.order()
|
||||
for x in self.neighbors(i):
|
||||
if x == j:
|
||||
continue
|
||||
w1 = self.get_edge_data(x, i)['weight']
|
||||
w2 = -1
|
||||
if self.has_edge(x, j):
|
||||
w2 = self.get_edge_data(x, j)['weight']
|
||||
|
||||
w = max(w1, w2)
|
||||
|
||||
self.add_edge(x, j, weight=w)
|
||||
|
||||
self.node[j]['labels'] += self.node[i]['labels']
|
||||
self.remove_node(i)
|
||||
@@ -5,14 +5,15 @@ def threshold_cut(label, rag, thresh):
|
||||
|
||||
#print [rag.edges_iter(data = True)]
|
||||
to_remove = [(x,y) for x,y,d in rag.edges_iter(data = True) if d['weight'] >= thresh]
|
||||
print "edges to remove",len(to_remove)
|
||||
#print "edges to remove",len(to_remove)
|
||||
|
||||
rag.remove_edges_from(to_remove)
|
||||
|
||||
#print "to remove", to_remove
|
||||
|
||||
comps = nx.connected_components(rag)
|
||||
out = np.copy(label)
|
||||
print "comps",len(comps)
|
||||
#print "comps",len(comps)
|
||||
|
||||
for i, nodes in enumerate(comps) :
|
||||
|
||||
|
||||
@@ -27,9 +27,9 @@ class RAG(nx.Graph):
|
||||
def rag_meancolor(img,labels):
|
||||
|
||||
img = util.img_as_ubyte(img)
|
||||
if img.ndim == 3 :
|
||||
if img.ndim == 4 :
|
||||
return _construct.construct_rag_meancolor_3d(img,labels)
|
||||
elif img.ndim == 2 :
|
||||
elif img.ndim == 3 :
|
||||
return _construct.construct_rag_meancolor_2d(img,labels)
|
||||
else :
|
||||
raise ValueError("Image dimension not supported")
|
||||
|
||||
@@ -1,10 +1,29 @@
|
||||
import numpy as np
|
||||
from skimage import graph
|
||||
|
||||
def test_threshold_cut():
|
||||
arr = np.array((100,100,3),dtype='uint8')
|
||||
arr[:50,:50] = 0
|
||||
arr[:50,50:] = 1
|
||||
arr[50:,50:] = 2
|
||||
arr[50:,50:] = 3
|
||||
|
||||
img = np.zeros((100,100,3),dtype='uint8')
|
||||
img[:50,:50] = 255,255,255
|
||||
img[:50,50:] = 254,254,254
|
||||
img[50:,:50] = 2,2,2
|
||||
img[50:,50:] = 1,1,1
|
||||
|
||||
|
||||
|
||||
labels = np.zeros((100,100),dtype='uint8')
|
||||
labels[:50,:50] = 0
|
||||
labels[:50,50:] = 1
|
||||
labels[50:,:50] = 2
|
||||
labels[50:,50:] = 3
|
||||
|
||||
|
||||
#print labels
|
||||
rag = graph.rag_meancolor(img, labels)
|
||||
#print "no of edges",rag.number_of_edges()
|
||||
new_labels = graph.threshold_cut(labels, rag, 10)
|
||||
|
||||
assert new_labels.max() == 2
|
||||
#assert False
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user