diff --git a/skimage/graph/_ncut.py b/skimage/graph/_ncut.py index 3e961ad4..4874b252 100644 --- a/skimage/graph/_ncut.py +++ b/skimage/graph/_ncut.py @@ -1,10 +1,11 @@ import networkx as nx import numpy as np from scipy import sparse +from . import _ncut_cy -def DW_matrix(graph): - """Returns the diagonal and weight matrix of a graph. +def DW_matrices(graph): + """Returns the diagonal and weight matrices of a graph. Parameters ---------- @@ -14,15 +15,15 @@ def DW_matrix(graph): Returns ------- D : csc_matrix - The diagonal matrix of the graph. `D[i,i]` is the sum of weights of all - edges incident on `i`. All other enteries are `0`. + The diagonal matrix of the graph. `D[i, i]` is the sum of weights of + all edges incident on `i`. All other enteries are `0`. W : csc_matrix - The weight matrix of the graph. `W[i,j]` is the weight of the edge + The weight matrix of the graph. `W[i, j]` is the weight of the edge joining `i` to `j`. """ #Cause sparse.eigsh prefers CSC format W = nx.to_scipy_sparse_matrix(graph, format='csc') - entries = W.sum(0) + entries = W.sum(axis=0) D = sparse.dia_matrix((entries, 0), shape=W.shape).tocsc() return D, W @@ -46,10 +47,9 @@ def ncut_cost(mask, D, W): The cost of performing the N-cut. """ mask = np.array(mask) - mask_list = [np.logical_xor(mask[i], mask) for i in range(mask.shape[0])] - mask_array = np.array(mask_list) + cut = _ncut_cy.cut_cost(mask, W) - cut = float(W[mask_array].sum() / 2.0) + # Cause D has elements only along diagonal assoc_a = D.data[mask].sum() assoc_b = D.data[np.logical_not(mask)].sum() diff --git a/skimage/graph/_ncut_cy.pyx b/skimage/graph/_ncut_cy.pyx index d8156ba4..8c416cbc 100644 --- a/skimage/graph/_ncut_cy.pyx +++ b/skimage/graph/_ncut_cy.pyx @@ -38,3 +38,30 @@ def argmin2(cnp.float64_t[:] array): i += 1 return i2 + + +def cut_cost(mask, W): + mask = np.array(mask) + + cdef Py_ssize_t num_rows, num_cols + cdef cnp.int32_t row, col + cdef cnp.int32_t[:] indices = W.indices + cdef cnp.int32_t[:] indptr = W.indptr + cdef cnp.float64_t[:] data = W.data + cdef cnp.int32_t row_index + cdef cnp.double_t cost = 0 + + num_rows = W.shape[0] + num_cols = W.shape[1] + + col = 0 + while col < num_cols: + row_index = indptr[col] + while row_index < indptr[col+1]: + row = indices[row_index] + if mask[row] != mask[col]: + cost += data[row_index] + row_index += 1 + col += 1 + + return cost*0.5 diff --git a/skimage/graph/graph_cut.py b/skimage/graph/graph_cut.py index 1010877f..af5048fd 100644 --- a/skimage/graph/graph_cut.py +++ b/skimage/graph/graph_cut.py @@ -134,7 +134,7 @@ def _ncut_relabel(rag, thresh, num_cuts, map_array): The array which maps old labels to new ones. This is modified inside the function. """ - d, w = _ncut.DW_matrix(rag) + d, w = _ncut.DW_matrices(rag) error = False try: