convert indActive to a bool if an integer list is provided

This commit is contained in:
Lindsey Heagy
2016-02-24 20:28:09 -08:00
parent 4e871a43a9
commit e3af1fd94e
2 changed files with 14 additions and 11 deletions
+6 -4
View File
@@ -4,6 +4,7 @@ class RegularizationMesh(object):
def __init__(self, mesh, indActive=None):
self.mesh = mesh
assert indActive is None or indActive.dtype == 'bool', 'indActive needs to be None or a bool'
self.indActive = indActive
@property
@@ -18,10 +19,7 @@ class RegularizationMesh(object):
if self.indActive is None:
self._nC = self.mesh.nC
else:
if self.indActive.dtype == 'bool':
self._nC = sum(self.indActive)
else:
self._nC = len(self.indActive) # you shouldn't pass a vector of int 0, 1 's
self._nC = sum(self.indActive)
return self._nC
@property
@@ -199,6 +197,10 @@ class BaseRegularization(object):
def __init__(self, mesh, mapping=None, indActive=None, **kwargs):
Utils.setKwargs(self, **kwargs)
assert isinstance(mesh, Mesh.BaseMesh), "mesh must be a SimPEG.Mesh object."
if indActive is not None and indActive.dtype != 'bool':
tmp = indActive
indActive = np.zeros(mesh.nC, dtype=bool)
indActive[tmp] = True
self.regmesh = RegularizationMesh(mesh,indActive)
self.mapping = mapping or self.mapPair(mesh)
self.mapping._assertMatchesPair(self.mapPair)
+8 -7
View File
@@ -59,17 +59,18 @@ class RegularizationTests(unittest.TestCase):
print 'Testing Active Cells %iD'%(mesh.dim)
if mesh.dim == 1:
indAct = Utils.mkvc(mesh.gridCC <= 0.8)
indActive = Utils.mkvc(mesh.gridCC <= 0.8)
elif mesh.dim == 2:
indAct = Utils.mkvc(mesh.gridCC[:,-1] <= 2*np.sin(2*np.pi*mesh.gridCC[:,0])+0.5)
indActive = Utils.mkvc(mesh.gridCC[:,-1] <= 2*np.sin(2*np.pi*mesh.gridCC[:,0])+0.5)
elif mesh.dim == 3:
indAct = Utils.mkvc(mesh.gridCC[:,-1] <= 2*np.sin(2*np.pi*mesh.gridCC[:,0])+0.5 * 2*np.sin(2*np.pi*mesh.gridCC[:,1])+0.5)
indActive = Utils.mkvc(mesh.gridCC[:,-1] <= 2*np.sin(2*np.pi*mesh.gridCC[:,0])+0.5 * 2*np.sin(2*np.pi*mesh.gridCC[:,1])+0.5)
mapping = Maps.IdentityMap(nP=indAct.nonzero()[0].size)
mapping = Maps.IdentityMap(nP=indActive.nonzero()[0].size)
reg = r(mesh, mapping=mapping, indActive=indAct)
m = np.random.rand(mesh.nC)[indAct]
reg.mref = np.ones_like(m)*np.mean(m)
for indAct in [indActive, indActive.nonzero()[0]]: # test both bool and integers
reg = r(mesh, mapping=mapping, indActive=indAct)
m = np.random.rand(mesh.nC)[indAct]
reg.mref = np.ones_like(m)*np.mean(m)
print 'Check: phi_m (mref) = %f' %reg.eval(reg.mref)
passed = reg.eval(reg.mref) < TOL