mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-02 00:46:36 +08:00
fixed sub2ind and ind2sub (thin wrappers on numpy) as per Issue #58
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
import unittest
|
||||
from SimPEG.Utils import mkvc, ndgrid, indexCube, sdiag, inv3X3BlockDiagonal, inv2X2BlockDiagonal
|
||||
from SimPEG.Utils import mkvc, ndgrid, indexCube, sdiag, inv3X3BlockDiagonal, inv2X2BlockDiagonal,sub2ind,ind2sub
|
||||
from SimPEG.Tests import checkDerivative
|
||||
|
||||
|
||||
@@ -64,6 +64,19 @@ class TestSequenceFunctions(unittest.TestCase):
|
||||
self.assertTrue(np.all(XYZ[:, 1] == X2_test))
|
||||
self.assertTrue(np.all(XYZ[:, 2] == X3_test))
|
||||
|
||||
def test_sub2ind(self):
|
||||
x = np.ones((5,2))
|
||||
self.assertTrue(np.all(sub2ind(x.shape, [0,0]) == [0]))
|
||||
self.assertTrue(np.all(sub2ind(x.shape, [4,0]) == [4]))
|
||||
self.assertTrue(np.all(sub2ind(x.shape, [0,1]) == [5]))
|
||||
self.assertTrue(np.all(sub2ind(x.shape, [4,1]) == [9]))
|
||||
self.assertTrue(np.all(sub2ind(x.shape, [[0,0],[4,0],[0,1],[4,1]]) == [0,4,5,9]))
|
||||
|
||||
def test_ind2sub(self):
|
||||
x = np.ones((5,2))
|
||||
self.assertTrue(np.all(ind2sub(x.shape, [0,4,5,9])[0] == [0,4,0,4]))
|
||||
self.assertTrue(np.all(ind2sub(x.shape, [0,4,5,9])[1] == [0,0,1,1]))
|
||||
|
||||
def test_indexCube_2D(self):
|
||||
nN = np.array([3, 3])
|
||||
self.assertTrue(np.all(indexCube('A', nN) == np.array([0, 1, 3, 4])))
|
||||
|
||||
+13
-26
@@ -97,35 +97,22 @@ def ndgrid(*args, **kwargs):
|
||||
else:
|
||||
return XYZ[2], XYZ[1], XYZ[0]
|
||||
|
||||
|
||||
def ind2sub(shape, ind):
|
||||
"""From the given shape, returns the subscrips of the given index"""
|
||||
revshp = []
|
||||
revshp.extend(shape)
|
||||
mult = [1]
|
||||
for i in range(0, len(revshp)-1):
|
||||
mult.extend([mult[i]*revshp[i]])
|
||||
mult = np.array(mult).reshape(len(mult))
|
||||
|
||||
sub = []
|
||||
|
||||
for i in range(0, len(shape)):
|
||||
sub.extend([np.math.floor(ind / mult[i])])
|
||||
ind = ind - (np.math.floor(ind/mult[i]) * mult[i])
|
||||
return sub
|
||||
|
||||
def ind2sub(shape, inds):
|
||||
"""From the given shape, returns the subscripts of the given index"""
|
||||
if type(inds) is not np.ndarray:
|
||||
inds = np.array(inds)
|
||||
assert len(inds.shape) == 1, 'Indexing must be done as a 1D row vector, e.g. [3,6,6,...]'
|
||||
return np.unravel_index(inds, shape, order='F')
|
||||
|
||||
def sub2ind(shape, subs):
|
||||
"""From the given shape, returns the index of the given subscript"""
|
||||
revshp = list(shape)
|
||||
mult = [1]
|
||||
for i in range(0, len(revshp)-1):
|
||||
mult.extend([mult[i]*revshp[i]])
|
||||
mult = np.array(mult).reshape(len(mult), 1)
|
||||
|
||||
idx = np.dot((subs), (mult))
|
||||
return idx
|
||||
|
||||
if type(subs) is not np.ndarray:
|
||||
subs = np.array(subs)
|
||||
if subs.size == len(shape):
|
||||
subs = subs[np.newaxis,:]
|
||||
assert subs.shape[1] == len(shape), 'Indexing must be done as a column vectors. e.g. [[3,6],[6,2],...]'
|
||||
inds = np.ravel_multi_index(subs.T, shape, order='F')
|
||||
return mkvc(inds)
|
||||
|
||||
def getSubArray(A, ind):
|
||||
"""subArray"""
|
||||
|
||||
Reference in New Issue
Block a user