From fdfc1c600e73cddf754f8af8e675f802a219405e Mon Sep 17 00:00:00 2001 From: rowanc1 Date: Wed, 19 Feb 2014 18:49:30 -0800 Subject: [PATCH] fixed sub2ind and ind2sub (thin wrappers on numpy) as per Issue #58 --- SimPEG/Tests/test_utils.py | 15 ++++++++++++++- SimPEG/Utils/matutils.py | 39 +++++++++++++------------------------- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/SimPEG/Tests/test_utils.py b/SimPEG/Tests/test_utils.py index fea231f2..12d79fb4 100644 --- a/SimPEG/Tests/test_utils.py +++ b/SimPEG/Tests/test_utils.py @@ -1,6 +1,6 @@ import numpy as np import unittest -from SimPEG.Utils import mkvc, ndgrid, indexCube, sdiag, inv3X3BlockDiagonal, inv2X2BlockDiagonal +from SimPEG.Utils import mkvc, ndgrid, indexCube, sdiag, inv3X3BlockDiagonal, inv2X2BlockDiagonal,sub2ind,ind2sub from SimPEG.Tests import checkDerivative @@ -64,6 +64,19 @@ class TestSequenceFunctions(unittest.TestCase): self.assertTrue(np.all(XYZ[:, 1] == X2_test)) self.assertTrue(np.all(XYZ[:, 2] == X3_test)) + def test_sub2ind(self): + x = np.ones((5,2)) + self.assertTrue(np.all(sub2ind(x.shape, [0,0]) == [0])) + self.assertTrue(np.all(sub2ind(x.shape, [4,0]) == [4])) + self.assertTrue(np.all(sub2ind(x.shape, [0,1]) == [5])) + self.assertTrue(np.all(sub2ind(x.shape, [4,1]) == [9])) + self.assertTrue(np.all(sub2ind(x.shape, [[0,0],[4,0],[0,1],[4,1]]) == [0,4,5,9])) + + def test_ind2sub(self): + x = np.ones((5,2)) + self.assertTrue(np.all(ind2sub(x.shape, [0,4,5,9])[0] == [0,4,0,4])) + self.assertTrue(np.all(ind2sub(x.shape, [0,4,5,9])[1] == [0,0,1,1])) + def test_indexCube_2D(self): nN = np.array([3, 3]) self.assertTrue(np.all(indexCube('A', nN) == np.array([0, 1, 3, 4]))) diff --git a/SimPEG/Utils/matutils.py b/SimPEG/Utils/matutils.py index 24286cdd..b2e1daf2 100644 --- a/SimPEG/Utils/matutils.py +++ b/SimPEG/Utils/matutils.py @@ -97,35 +97,22 @@ def ndgrid(*args, **kwargs): else: return XYZ[2], XYZ[1], XYZ[0] - -def ind2sub(shape, ind): - """From the given shape, returns the subscrips of the given index""" - revshp = [] - revshp.extend(shape) - mult = [1] - for i in range(0, len(revshp)-1): - mult.extend([mult[i]*revshp[i]]) - mult = np.array(mult).reshape(len(mult)) - - sub = [] - - for i in range(0, len(shape)): - sub.extend([np.math.floor(ind / mult[i])]) - ind = ind - (np.math.floor(ind/mult[i]) * mult[i]) - return sub - +def ind2sub(shape, inds): + """From the given shape, returns the subscripts of the given index""" + if type(inds) is not np.ndarray: + inds = np.array(inds) + assert len(inds.shape) == 1, 'Indexing must be done as a 1D row vector, e.g. [3,6,6,...]' + return np.unravel_index(inds, shape, order='F') def sub2ind(shape, subs): """From the given shape, returns the index of the given subscript""" - revshp = list(shape) - mult = [1] - for i in range(0, len(revshp)-1): - mult.extend([mult[i]*revshp[i]]) - mult = np.array(mult).reshape(len(mult), 1) - - idx = np.dot((subs), (mult)) - return idx - + if type(subs) is not np.ndarray: + subs = np.array(subs) + if subs.size == len(shape): + subs = subs[np.newaxis,:] + assert subs.shape[1] == len(shape), 'Indexing must be done as a column vectors. e.g. [[3,6],[6,2],...]' + inds = np.ravel_multi_index(subs.T, shape, order='F') + return mkvc(inds) def getSubArray(A, ind): """subArray"""