From abcb9cf2efca91c320458fec7a67ba8782b30df0 Mon Sep 17 00:00:00 2001 From: Vighnesh Birodkar Date: Mon, 16 Jun 2014 19:22:47 +0530 Subject: [PATCH] Added testcase --- skimage/graph/__init__.py | 6 +++++- skimage/graph/graph.py | 23 ----------------------- skimage/graph/graph_cut.py | 5 +++-- skimage/graph/rag.py | 4 ++-- skimage/graph/tests/test_rag.py | 29 ++++++++++++++++++++++++----- 5 files changed, 34 insertions(+), 33 deletions(-) delete mode 100644 skimage/graph/graph.py diff --git a/skimage/graph/__init__.py b/skimage/graph/__init__.py index a335971d..90da2088 100644 --- a/skimage/graph/__init__.py +++ b/skimage/graph/__init__.py @@ -1,9 +1,13 @@ from .spath import shortest_path from .mcp import MCP, MCP_Geometric, MCP_Connect, MCP_Flexible, route_through_array +from .rag import rag_meancolor +from .graph_cut import threshold_cut __all__ = ['shortest_path', 'MCP', 'MCP_Geometric', 'MCP_Connect', 'MCP_Flexible', - 'route_through_array'] \ No newline at end of file + 'route_through_array', + 'rag_meancolor', + 'threshold_cut'] diff --git a/skimage/graph/graph.py b/skimage/graph/graph.py deleted file mode 100644 index 2597e985..00000000 --- a/skimage/graph/graph.py +++ /dev/null @@ -1,23 +0,0 @@ -import netwrokx as nx - -class Graph(nx.Graph): - - def merge_nodes(i,j): - if not self.has_edge(i, j): - raise ValueError('Cant merge non adjacent nodes') - - # print "before ",self.order() - for x in self.neighbors(i): - if x == j: - continue - w1 = self.get_edge_data(x, i)['weight'] - w2 = -1 - if self.has_edge(x, j): - w2 = self.get_edge_data(x, j)['weight'] - - w = max(w1, w2) - - self.add_edge(x, j, weight=w) - - self.node[j]['labels'] += self.node[i]['labels'] - self.remove_node(i) diff --git a/skimage/graph/graph_cut.py b/skimage/graph/graph_cut.py index bf55db84..5d6a98e2 100644 --- a/skimage/graph/graph_cut.py +++ b/skimage/graph/graph_cut.py @@ -5,14 +5,15 @@ def threshold_cut(label, rag, thresh): #print [rag.edges_iter(data = True)] to_remove = [(x,y) for x,y,d in rag.edges_iter(data = True) if d['weight'] >= thresh] - print "edges to remove",len(to_remove) + #print "edges to remove",len(to_remove) rag.remove_edges_from(to_remove) + #print "to remove", to_remove comps = nx.connected_components(rag) out = np.copy(label) - print "comps",len(comps) + #print "comps",len(comps) for i, nodes in enumerate(comps) : diff --git a/skimage/graph/rag.py b/skimage/graph/rag.py index 9795787b..dd721959 100644 --- a/skimage/graph/rag.py +++ b/skimage/graph/rag.py @@ -27,9 +27,9 @@ class RAG(nx.Graph): def rag_meancolor(img,labels): img = util.img_as_ubyte(img) - if img.ndim == 3 : + if img.ndim == 4 : return _construct.construct_rag_meancolor_3d(img,labels) - elif img.ndim == 2 : + elif img.ndim == 3 : return _construct.construct_rag_meancolor_2d(img,labels) else : raise ValueError("Image dimension not supported") diff --git a/skimage/graph/tests/test_rag.py b/skimage/graph/tests/test_rag.py index 08d1a1cf..d24c62c4 100644 --- a/skimage/graph/tests/test_rag.py +++ b/skimage/graph/tests/test_rag.py @@ -1,10 +1,29 @@ import numpy as np +from skimage import graph def test_threshold_cut(): - arr = np.array((100,100,3),dtype='uint8') - arr[:50,:50] = 0 - arr[:50,50:] = 1 - arr[50:,50:] = 2 - arr[50:,50:] = 3 + + img = np.zeros((100,100,3),dtype='uint8') + img[:50,:50] = 255,255,255 + img[:50,50:] = 254,254,254 + img[50:,:50] = 2,2,2 + img[50:,50:] = 1,1,1 + + + + labels = np.zeros((100,100),dtype='uint8') + labels[:50,:50] = 0 + labels[:50,50:] = 1 + labels[50:,:50] = 2 + labels[50:,50:] = 3 + + + #print labels + rag = graph.rag_meancolor(img, labels) + #print "no of edges",rag.number_of_edges() + new_labels = graph.threshold_cut(labels, rag, 10) + + assert new_labels.max() == 2 + #assert False