Added testcase

This commit is contained in:
Vighnesh Birodkar
2014-06-16 19:22:47 +05:30
parent a6c9a5a2a7
commit abcb9cf2ef
5 changed files with 34 additions and 33 deletions
+5 -1
View File
@@ -1,9 +1,13 @@
from .spath import shortest_path
from .mcp import MCP, MCP_Geometric, MCP_Connect, MCP_Flexible, route_through_array
from .rag import rag_meancolor
from .graph_cut import threshold_cut
__all__ = ['shortest_path',
'MCP',
'MCP_Geometric',
'MCP_Connect',
'MCP_Flexible',
'route_through_array']
'route_through_array',
'rag_meancolor',
'threshold_cut']
-23
View File
@@ -1,23 +0,0 @@
import netwrokx as nx
class Graph(nx.Graph):
def merge_nodes(i,j):
if not self.has_edge(i, j):
raise ValueError('Cant merge non adjacent nodes')
# print "before ",self.order()
for x in self.neighbors(i):
if x == j:
continue
w1 = self.get_edge_data(x, i)['weight']
w2 = -1
if self.has_edge(x, j):
w2 = self.get_edge_data(x, j)['weight']
w = max(w1, w2)
self.add_edge(x, j, weight=w)
self.node[j]['labels'] += self.node[i]['labels']
self.remove_node(i)
+3 -2
View File
@@ -5,14 +5,15 @@ def threshold_cut(label, rag, thresh):
#print [rag.edges_iter(data = True)]
to_remove = [(x,y) for x,y,d in rag.edges_iter(data = True) if d['weight'] >= thresh]
print "edges to remove",len(to_remove)
#print "edges to remove",len(to_remove)
rag.remove_edges_from(to_remove)
#print "to remove", to_remove
comps = nx.connected_components(rag)
out = np.copy(label)
print "comps",len(comps)
#print "comps",len(comps)
for i, nodes in enumerate(comps) :
+2 -2
View File
@@ -27,9 +27,9 @@ class RAG(nx.Graph):
def rag_meancolor(img,labels):
img = util.img_as_ubyte(img)
if img.ndim == 3 :
if img.ndim == 4 :
return _construct.construct_rag_meancolor_3d(img,labels)
elif img.ndim == 2 :
elif img.ndim == 3 :
return _construct.construct_rag_meancolor_2d(img,labels)
else :
raise ValueError("Image dimension not supported")
+24 -5
View File
@@ -1,10 +1,29 @@
import numpy as np
from skimage import graph
def test_threshold_cut():
arr = np.array((100,100,3),dtype='uint8')
arr[:50,:50] = 0
arr[:50,50:] = 1
arr[50:,50:] = 2
arr[50:,50:] = 3
img = np.zeros((100,100,3),dtype='uint8')
img[:50,:50] = 255,255,255
img[:50,50:] = 254,254,254
img[50:,:50] = 2,2,2
img[50:,50:] = 1,1,1
labels = np.zeros((100,100),dtype='uint8')
labels[:50,:50] = 0
labels[:50,50:] = 1
labels[50:,:50] = 2
labels[50:,50:] = 3
#print labels
rag = graph.rag_meancolor(img, labels)
#print "no of edges",rag.number_of_edges()
new_labels = graph.threshold_cut(labels, rag, 10)
assert new_labels.max() == 2
#assert False