From e3af1fd94e9828caa57a931074fa01a91a3ebf1b Mon Sep 17 00:00:00 2001 From: Lindsey Heagy Date: Wed, 24 Feb 2016 20:28:09 -0800 Subject: [PATCH] convert indActive to a bool if an integer list is provided --- SimPEG/Regularization.py | 10 ++++++---- tests/base/test_regularization.py | 15 ++++++++------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/SimPEG/Regularization.py b/SimPEG/Regularization.py index 31013dc5..d0fc21b1 100644 --- a/SimPEG/Regularization.py +++ b/SimPEG/Regularization.py @@ -4,6 +4,7 @@ class RegularizationMesh(object): def __init__(self, mesh, indActive=None): self.mesh = mesh + assert indActive is None or indActive.dtype == 'bool', 'indActive needs to be None or a bool' self.indActive = indActive @property @@ -18,10 +19,7 @@ class RegularizationMesh(object): if self.indActive is None: self._nC = self.mesh.nC else: - if self.indActive.dtype == 'bool': - self._nC = sum(self.indActive) - else: - self._nC = len(self.indActive) # you shouldn't pass a vector of int 0, 1 's + self._nC = sum(self.indActive) return self._nC @property @@ -199,6 +197,10 @@ class BaseRegularization(object): def __init__(self, mesh, mapping=None, indActive=None, **kwargs): Utils.setKwargs(self, **kwargs) assert isinstance(mesh, Mesh.BaseMesh), "mesh must be a SimPEG.Mesh object." + if indActive is not None and indActive.dtype != 'bool': + tmp = indActive + indActive = np.zeros(mesh.nC, dtype=bool) + indActive[tmp] = True self.regmesh = RegularizationMesh(mesh,indActive) self.mapping = mapping or self.mapPair(mesh) self.mapping._assertMatchesPair(self.mapPair) diff --git a/tests/base/test_regularization.py b/tests/base/test_regularization.py index 614d3158..52cef349 100644 --- a/tests/base/test_regularization.py +++ b/tests/base/test_regularization.py @@ -59,17 +59,18 @@ class RegularizationTests(unittest.TestCase): print 'Testing Active Cells %iD'%(mesh.dim) if mesh.dim == 1: - indAct = Utils.mkvc(mesh.gridCC <= 0.8) + indActive = Utils.mkvc(mesh.gridCC <= 0.8) elif mesh.dim == 2: - indAct = Utils.mkvc(mesh.gridCC[:,-1] <= 2*np.sin(2*np.pi*mesh.gridCC[:,0])+0.5) + indActive = Utils.mkvc(mesh.gridCC[:,-1] <= 2*np.sin(2*np.pi*mesh.gridCC[:,0])+0.5) elif mesh.dim == 3: - indAct = Utils.mkvc(mesh.gridCC[:,-1] <= 2*np.sin(2*np.pi*mesh.gridCC[:,0])+0.5 * 2*np.sin(2*np.pi*mesh.gridCC[:,1])+0.5) + indActive = Utils.mkvc(mesh.gridCC[:,-1] <= 2*np.sin(2*np.pi*mesh.gridCC[:,0])+0.5 * 2*np.sin(2*np.pi*mesh.gridCC[:,1])+0.5) - mapping = Maps.IdentityMap(nP=indAct.nonzero()[0].size) + mapping = Maps.IdentityMap(nP=indActive.nonzero()[0].size) - reg = r(mesh, mapping=mapping, indActive=indAct) - m = np.random.rand(mesh.nC)[indAct] - reg.mref = np.ones_like(m)*np.mean(m) + for indAct in [indActive, indActive.nonzero()[0]]: # test both bool and integers + reg = r(mesh, mapping=mapping, indActive=indAct) + m = np.random.rand(mesh.nC)[indAct] + reg.mref = np.ones_like(m)*np.mean(m) print 'Check: phi_m (mref) = %f' %reg.eval(reg.mref) passed = reg.eval(reg.mref) < TOL