return a scipy sparse matrix for the deriv (a bit silly - it is dense, but nicer for multiplication). Init Regularization with a nP

This commit is contained in:
Lindsey Heagy
2016-05-28 15:40:28 -07:00
parent 341b98d23a
commit e8e022fcc6
2 changed files with 16 additions and 8 deletions
+3 -3
View File
@@ -19,7 +19,7 @@ class IdentityMap(object):
Utils.setKwargs(self, **kwargs)
if nP is not None:
assert type(nP) in [int, long], ' Number of parameters must be an integer.'
assert type(nP) in [int, long, np.int64], ' Number of parameters must be an integer.'
self.mesh = mesh
self._nP = nP
@@ -1492,7 +1492,7 @@ class ParametrizedBlockInLayer(IdentityMap):
self._validate_m(m) # make sure things are the right sizes
if self.mesh.dim == 2:
return self._deriv2d(m)
return sp.csr_matrix(self._deriv2d(m))
elif self.mesh.dim == 3:
return self._deriv3d(m)
return sp.csr_matrix(self._deriv3d(m))
+13 -5
View File
@@ -39,7 +39,7 @@ class RegularizationMesh(object):
if self.indActive is None:
self._nC = self.mesh.nC
else:
self._nC = sum(self.indActive)
self._nC = int(sum(self.indActive))
return self._nC
@property
@@ -304,7 +304,7 @@ class BaseRegularization(object):
mesh = None #: A SimPEG.Mesh instance.
mref = None #: Reference model.
def __init__(self, mesh, mapping=None, indActive=None, **kwargs):
def __init__(self, mesh=None, nP=None, mapping=None, indActive=None, **kwargs):
Utils.setKwargs(self, **kwargs)
assert isinstance(mesh, Mesh.BaseMesh), "mesh must be a SimPEG.Mesh object."
if indActive is not None and indActive.dtype != 'bool':
@@ -314,11 +314,19 @@ class BaseRegularization(object):
if indActive is not None and mapping is None:
mapping = Maps.IdentityMap(nP=indActive.nonzero()[0].size)
if mesh is None and nP is None:
raise Exception, 'either Mesh or number of parameters must be provided to the BaseRegularization'
self.regmesh = RegularizationMesh(mesh,indActive)
self.mapping = mapping or self.mapPair(mesh)
self.mapping._assertMatchesPair(self.mapPair)
self.indActive = indActive
if mesh is not None and nP is None:
nP = self.regmesh.nC
self.nP = nP
self.mapping = mapping or self.mapPair(nP=self.nP)
self.mapping._assertMatchesPair(self.mapPair)
@property
def parent(self):
"""This is the parent of the regularization."""
@@ -346,7 +354,7 @@ class BaseRegularization(object):
@property
def W(self):
"""Full regularization weighting matrix W."""
return sp.identity(self.regmesh.nC)
return sp.identity(self.nP)
@Utils.timeIt
def eval(self, m):