mirror of
https://github.com/wassname/scikit-image.git
synced 2026-07-03 03:46:06 +08:00
Cut cost is computed by Cython code
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from scipy import sparse
|
||||
from . import _ncut_cy
|
||||
|
||||
|
||||
def DW_matrix(graph):
|
||||
"""Returns the diagonal and weight matrix of a graph.
|
||||
def DW_matrices(graph):
|
||||
"""Returns the diagonal and weight matrices of a graph.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -14,15 +15,15 @@ def DW_matrix(graph):
|
||||
Returns
|
||||
-------
|
||||
D : csc_matrix
|
||||
The diagonal matrix of the graph. `D[i,i]` is the sum of weights of all
|
||||
edges incident on `i`. All other enteries are `0`.
|
||||
The diagonal matrix of the graph. `D[i, i]` is the sum of weights of
|
||||
all edges incident on `i`. All other enteries are `0`.
|
||||
W : csc_matrix
|
||||
The weight matrix of the graph. `W[i,j]` is the weight of the edge
|
||||
The weight matrix of the graph. `W[i, j]` is the weight of the edge
|
||||
joining `i` to `j`.
|
||||
"""
|
||||
#Cause sparse.eigsh prefers CSC format
|
||||
W = nx.to_scipy_sparse_matrix(graph, format='csc')
|
||||
entries = W.sum(0)
|
||||
entries = W.sum(axis=0)
|
||||
D = sparse.dia_matrix((entries, 0), shape=W.shape).tocsc()
|
||||
return D, W
|
||||
|
||||
@@ -46,10 +47,9 @@ def ncut_cost(mask, D, W):
|
||||
The cost of performing the N-cut.
|
||||
"""
|
||||
mask = np.array(mask)
|
||||
mask_list = [np.logical_xor(mask[i], mask) for i in range(mask.shape[0])]
|
||||
mask_array = np.array(mask_list)
|
||||
cut = _ncut_cy.cut_cost(mask, W)
|
||||
|
||||
cut = float(W[mask_array].sum() / 2.0)
|
||||
# Cause D has elements only along diagonal
|
||||
assoc_a = D.data[mask].sum()
|
||||
assoc_b = D.data[np.logical_not(mask)].sum()
|
||||
|
||||
|
||||
@@ -38,3 +38,30 @@ def argmin2(cnp.float64_t[:] array):
|
||||
i += 1
|
||||
|
||||
return i2
|
||||
|
||||
|
||||
def cut_cost(mask, W):
|
||||
mask = np.array(mask)
|
||||
|
||||
cdef Py_ssize_t num_rows, num_cols
|
||||
cdef cnp.int32_t row, col
|
||||
cdef cnp.int32_t[:] indices = W.indices
|
||||
cdef cnp.int32_t[:] indptr = W.indptr
|
||||
cdef cnp.float64_t[:] data = W.data
|
||||
cdef cnp.int32_t row_index
|
||||
cdef cnp.double_t cost = 0
|
||||
|
||||
num_rows = W.shape[0]
|
||||
num_cols = W.shape[1]
|
||||
|
||||
col = 0
|
||||
while col < num_cols:
|
||||
row_index = indptr[col]
|
||||
while row_index < indptr[col+1]:
|
||||
row = indices[row_index]
|
||||
if mask[row] != mask[col]:
|
||||
cost += <cnp.double_t>data[row_index]
|
||||
row_index += 1
|
||||
col += 1
|
||||
|
||||
return cost*0.5
|
||||
|
||||
@@ -134,7 +134,7 @@ def _ncut_relabel(rag, thresh, num_cuts, map_array):
|
||||
The array which maps old labels to new ones. This is modified inside
|
||||
the function.
|
||||
"""
|
||||
d, w = _ncut.DW_matrix(rag)
|
||||
d, w = _ncut.DW_matrices(rag)
|
||||
error = False
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user