Cut cost is computed by Cython code

This commit is contained in:
Vighnesh Birodkar
2014-07-31 23:05:47 +05:30
parent d8c0b2e7dd
commit 07cb79cd27
3 changed files with 37 additions and 10 deletions
+9 -9
View File
@@ -1,10 +1,11 @@
import networkx as nx
import numpy as np
from scipy import sparse
from . import _ncut_cy
def DW_matrix(graph):
"""Returns the diagonal and weight matrix of a graph.
def DW_matrices(graph):
"""Returns the diagonal and weight matrices of a graph.
Parameters
----------
@@ -14,15 +15,15 @@ def DW_matrix(graph):
Returns
-------
D : csc_matrix
The diagonal matrix of the graph. `D[i,i]` is the sum of weights of all
edges incident on `i`. All other enteries are `0`.
The diagonal matrix of the graph. `D[i, i]` is the sum of weights of
all edges incident on `i`. All other enteries are `0`.
W : csc_matrix
The weight matrix of the graph. `W[i,j]` is the weight of the edge
The weight matrix of the graph. `W[i, j]` is the weight of the edge
joining `i` to `j`.
"""
#Cause sparse.eigsh prefers CSC format
W = nx.to_scipy_sparse_matrix(graph, format='csc')
entries = W.sum(0)
entries = W.sum(axis=0)
D = sparse.dia_matrix((entries, 0), shape=W.shape).tocsc()
return D, W
@@ -46,10 +47,9 @@ def ncut_cost(mask, D, W):
The cost of performing the N-cut.
"""
mask = np.array(mask)
mask_list = [np.logical_xor(mask[i], mask) for i in range(mask.shape[0])]
mask_array = np.array(mask_list)
cut = _ncut_cy.cut_cost(mask, W)
cut = float(W[mask_array].sum() / 2.0)
# Cause D has elements only along diagonal
assoc_a = D.data[mask].sum()
assoc_b = D.data[np.logical_not(mask)].sum()
+27
View File
@@ -38,3 +38,30 @@ def argmin2(cnp.float64_t[:] array):
i += 1
return i2
def cut_cost(mask, W):
mask = np.array(mask)
cdef Py_ssize_t num_rows, num_cols
cdef cnp.int32_t row, col
cdef cnp.int32_t[:] indices = W.indices
cdef cnp.int32_t[:] indptr = W.indptr
cdef cnp.float64_t[:] data = W.data
cdef cnp.int32_t row_index
cdef cnp.double_t cost = 0
num_rows = W.shape[0]
num_cols = W.shape[1]
col = 0
while col < num_cols:
row_index = indptr[col]
while row_index < indptr[col+1]:
row = indices[row_index]
if mask[row] != mask[col]:
cost += <cnp.double_t>data[row_index]
row_index += 1
col += 1
return cost*0.5
+1 -1
View File
@@ -134,7 +134,7 @@ def _ncut_relabel(rag, thresh, num_cuts, map_array):
The array which maps old labels to new ones. This is modified inside
the function.
"""
d, w = _ncut.DW_matrix(rag)
d, w = _ncut.DW_matrices(rag)
error = False
try: