mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-02 07:05:12 +08:00
convert indActive to a bool if an integer list is provided
This commit is contained in:
@@ -4,6 +4,7 @@ class RegularizationMesh(object):
|
||||
|
||||
def __init__(self, mesh, indActive=None):
|
||||
self.mesh = mesh
|
||||
assert indActive is None or indActive.dtype == 'bool', 'indActive needs to be None or a bool'
|
||||
self.indActive = indActive
|
||||
|
||||
@property
|
||||
@@ -18,10 +19,7 @@ class RegularizationMesh(object):
|
||||
if self.indActive is None:
|
||||
self._nC = self.mesh.nC
|
||||
else:
|
||||
if self.indActive.dtype == 'bool':
|
||||
self._nC = sum(self.indActive)
|
||||
else:
|
||||
self._nC = len(self.indActive) # you shouldn't pass a vector of int 0, 1 's
|
||||
self._nC = sum(self.indActive)
|
||||
return self._nC
|
||||
|
||||
@property
|
||||
@@ -199,6 +197,10 @@ class BaseRegularization(object):
|
||||
def __init__(self, mesh, mapping=None, indActive=None, **kwargs):
|
||||
Utils.setKwargs(self, **kwargs)
|
||||
assert isinstance(mesh, Mesh.BaseMesh), "mesh must be a SimPEG.Mesh object."
|
||||
if indActive is not None and indActive.dtype != 'bool':
|
||||
tmp = indActive
|
||||
indActive = np.zeros(mesh.nC, dtype=bool)
|
||||
indActive[tmp] = True
|
||||
self.regmesh = RegularizationMesh(mesh,indActive)
|
||||
self.mapping = mapping or self.mapPair(mesh)
|
||||
self.mapping._assertMatchesPair(self.mapPair)
|
||||
|
||||
@@ -59,17 +59,18 @@ class RegularizationTests(unittest.TestCase):
|
||||
print 'Testing Active Cells %iD'%(mesh.dim)
|
||||
|
||||
if mesh.dim == 1:
|
||||
indAct = Utils.mkvc(mesh.gridCC <= 0.8)
|
||||
indActive = Utils.mkvc(mesh.gridCC <= 0.8)
|
||||
elif mesh.dim == 2:
|
||||
indAct = Utils.mkvc(mesh.gridCC[:,-1] <= 2*np.sin(2*np.pi*mesh.gridCC[:,0])+0.5)
|
||||
indActive = Utils.mkvc(mesh.gridCC[:,-1] <= 2*np.sin(2*np.pi*mesh.gridCC[:,0])+0.5)
|
||||
elif mesh.dim == 3:
|
||||
indAct = Utils.mkvc(mesh.gridCC[:,-1] <= 2*np.sin(2*np.pi*mesh.gridCC[:,0])+0.5 * 2*np.sin(2*np.pi*mesh.gridCC[:,1])+0.5)
|
||||
indActive = Utils.mkvc(mesh.gridCC[:,-1] <= 2*np.sin(2*np.pi*mesh.gridCC[:,0])+0.5 * 2*np.sin(2*np.pi*mesh.gridCC[:,1])+0.5)
|
||||
|
||||
mapping = Maps.IdentityMap(nP=indAct.nonzero()[0].size)
|
||||
mapping = Maps.IdentityMap(nP=indActive.nonzero()[0].size)
|
||||
|
||||
reg = r(mesh, mapping=mapping, indActive=indAct)
|
||||
m = np.random.rand(mesh.nC)[indAct]
|
||||
reg.mref = np.ones_like(m)*np.mean(m)
|
||||
for indAct in [indActive, indActive.nonzero()[0]]: # test both bool and integers
|
||||
reg = r(mesh, mapping=mapping, indActive=indAct)
|
||||
m = np.random.rand(mesh.nC)[indAct]
|
||||
reg.mref = np.ones_like(m)*np.mean(m)
|
||||
|
||||
print 'Check: phi_m (mref) = %f' %reg.eval(reg.mref)
|
||||
passed = reg.eval(reg.mref) < TOL
|
||||
|
||||
Reference in New Issue
Block a user