mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-27 23:40:00 +08:00
return a scipy sparse matrix for the deriv (a bit silly - it is dense, but nicer for multiplication). Init Regularization with a nP
This commit is contained in:
+3
-3
@@ -19,7 +19,7 @@ class IdentityMap(object):
|
||||
Utils.setKwargs(self, **kwargs)
|
||||
|
||||
if nP is not None:
|
||||
assert type(nP) in [int, long], ' Number of parameters must be an integer.'
|
||||
assert type(nP) in [int, long, np.int64], ' Number of parameters must be an integer.'
|
||||
|
||||
self.mesh = mesh
|
||||
self._nP = nP
|
||||
@@ -1492,7 +1492,7 @@ class ParametrizedBlockInLayer(IdentityMap):
|
||||
self._validate_m(m) # make sure things are the right sizes
|
||||
|
||||
if self.mesh.dim == 2:
|
||||
return self._deriv2d(m)
|
||||
return sp.csr_matrix(self._deriv2d(m))
|
||||
elif self.mesh.dim == 3:
|
||||
return self._deriv3d(m)
|
||||
return sp.csr_matrix(self._deriv3d(m))
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ class RegularizationMesh(object):
|
||||
if self.indActive is None:
|
||||
self._nC = self.mesh.nC
|
||||
else:
|
||||
self._nC = sum(self.indActive)
|
||||
self._nC = int(sum(self.indActive))
|
||||
return self._nC
|
||||
|
||||
@property
|
||||
@@ -304,7 +304,7 @@ class BaseRegularization(object):
|
||||
mesh = None #: A SimPEG.Mesh instance.
|
||||
mref = None #: Reference model.
|
||||
|
||||
def __init__(self, mesh, mapping=None, indActive=None, **kwargs):
|
||||
def __init__(self, mesh=None, nP=None, mapping=None, indActive=None, **kwargs):
|
||||
Utils.setKwargs(self, **kwargs)
|
||||
assert isinstance(mesh, Mesh.BaseMesh), "mesh must be a SimPEG.Mesh object."
|
||||
if indActive is not None and indActive.dtype != 'bool':
|
||||
@@ -314,11 +314,19 @@ class BaseRegularization(object):
|
||||
if indActive is not None and mapping is None:
|
||||
mapping = Maps.IdentityMap(nP=indActive.nonzero()[0].size)
|
||||
|
||||
if mesh is None and nP is None:
|
||||
raise Exception, 'either Mesh or number of parameters must be provided to the BaseRegularization'
|
||||
|
||||
self.regmesh = RegularizationMesh(mesh,indActive)
|
||||
self.mapping = mapping or self.mapPair(mesh)
|
||||
self.mapping._assertMatchesPair(self.mapPair)
|
||||
self.indActive = indActive
|
||||
|
||||
if mesh is not None and nP is None:
|
||||
nP = self.regmesh.nC
|
||||
self.nP = nP
|
||||
|
||||
self.mapping = mapping or self.mapPair(nP=self.nP)
|
||||
self.mapping._assertMatchesPair(self.mapPair)
|
||||
|
||||
@property
|
||||
def parent(self):
|
||||
"""This is the parent of the regularization."""
|
||||
@@ -346,7 +354,7 @@ class BaseRegularization(object):
|
||||
@property
|
||||
def W(self):
|
||||
"""Full regularization weighting matrix W."""
|
||||
return sp.identity(self.regmesh.nC)
|
||||
return sp.identity(self.nP)
|
||||
|
||||
@Utils.timeIt
|
||||
def eval(self, m):
|
||||
|
||||
Reference in New Issue
Block a user