Added tests

This commit is contained in:
Vighnesh Birodkar
2015-01-27 22:31:15 +05:30
parent 05a996a846
commit ec17cb29c0
2 changed files with 55 additions and 5 deletions
+5 -4
View File
@@ -24,8 +24,9 @@ def _revalidate_node_edges(rag, node, heap_list):
for nbr in rag.neighbors(node):
data = rag[node][nbr]
try:
# invalidate existing neghbors of `dst`, they have new weights
# invalidate edges incident on `dst`, they have new weights
data['heap item'][3] = False
_invalidate_edge(rag, node, nbr)
except KeyError:
# will handle the case where the edge did not exist in the existing
# graph
@@ -38,7 +39,7 @@ def _revalidate_node_edges(rag, node, heap_list):
def _rename_node(graph, node_id, copy_id):
""" Renames `node_id` in `graph` to `copy_id`. """
""" Rename `node_id` in `graph` to `copy_id`. """
graph._add_node_silent(copy_id)
graph.node[copy_id] = graph.node[node_id]
@@ -70,9 +71,9 @@ def merge_hierarchical(labels, rag, thresh, rag_copy, in_place_merge,
thresh : float
Regions connected by an edge with weight smaller than `thresh` are
merged.
rag_copy : bool, optional
rag_copy : bool
If set, the RAG copied before modifying.
in_place_merge : bool, optional
in_place_merge : bool
If set, the nodes are merged in place. Otherwise, a new node is
created for each merge..
merge_func : callable
+50 -1
View File
@@ -2,7 +2,7 @@ import numpy as np
from skimage import graph
from skimage._shared.version_requirements import is_installed
from numpy.testing.decorators import skipif
from skimage import segmentation
from skimage import segmentation, io
from numpy import testing
@@ -108,3 +108,52 @@ def test_rag_error():
labels[5:, :] = 1
testing.assert_raises(ValueError, graph.rag_mean_color, img, labels,
2, 'non existant mode')
@skipif(not is_installed('networkx'))
def test_rag_error():
img = np.zeros((10, 10, 3), dtype='uint8')
labels = np.zeros((10, 10), dtype='uint8')
labels[:5, :] = 0
labels[5:, :] = 1
testing.assert_raises(ValueError, graph.rag_mean_color, img, labels,
2, 'non existant mode')
def _weight_mean_color(graph, src, dst, n):
#print 'merging
diff = graph.node[dst]['mean color'] - graph.node[n]['mean color']
diff = np.linalg.norm(diff)
return diff
def _pre_merge_mean_color(graph, src, dst):
graph.node[dst]['total color'] += graph.node[src]['total color']
graph.node[dst]['pixel count'] += graph.node[src]['pixel count']
graph.node[dst]['mean color'] = (graph.node[dst]['total color'] /
graph.node[dst]['pixel count'])
def merge_hierarchical_mean_color(labels, rag, thresh, rag_copy=True,
in_place_merge=False):
return graph.merge_hierarchical(labels, rag, thresh, rag_copy,
in_place_merge, _pre_merge_mean_color,
_weight_mean_color)
@skipif(not is_installed('networkx'))
def test_rag_hierarchical():
img = np.zeros((8, 8, 3), dtype='uint8')
labels = np.zeros((8, 8), dtype='uint8')
img[:, :, :] = 128
labels[:,:] = 1
img[0:4,0:4,:] = 255,255,255
labels[0:4, 0:4] = 2
img[4:, 0:4,:] = 0,0,0
labels[4:, 0:4] = 3
g = graph.rag_mean_color(img, labels)
result = merge_hierarchical_mean_color(labels, g, 300)
assert len(np.unique(result)) == 1
io.imsave('/home/vighnesh/Desktop/test.png', img)