diff --git a/SimPEG/Maps.py b/SimPEG/Maps.py index be789a2a..a7f5a0ab 100644 --- a/SimPEG/Maps.py +++ b/SimPEG/Maps.py @@ -19,7 +19,7 @@ class IdentityMap(object): Utils.setKwargs(self, **kwargs) if nP is not None: - assert type(nP) in [int, long], ' Number of parameters must be an integer.' + assert type(nP) in [int, long, np.int64], ' Number of parameters must be an integer.' self.mesh = mesh self._nP = nP @@ -1492,7 +1492,7 @@ class ParametrizedBlockInLayer(IdentityMap): self._validate_m(m) # make sure things are the right sizes if self.mesh.dim == 2: - return self._deriv2d(m) + return sp.csr_matrix(self._deriv2d(m)) elif self.mesh.dim == 3: - return self._deriv3d(m) + return sp.csr_matrix(self._deriv3d(m)) diff --git a/SimPEG/Regularization.py b/SimPEG/Regularization.py index fc101a61..d4de1705 100644 --- a/SimPEG/Regularization.py +++ b/SimPEG/Regularization.py @@ -39,7 +39,7 @@ class RegularizationMesh(object): if self.indActive is None: self._nC = self.mesh.nC else: - self._nC = sum(self.indActive) + self._nC = int(sum(self.indActive)) return self._nC @property @@ -304,7 +304,7 @@ class BaseRegularization(object): mesh = None #: A SimPEG.Mesh instance. mref = None #: Reference model. - def __init__(self, mesh, mapping=None, indActive=None, **kwargs): + def __init__(self, mesh=None, nP=None, mapping=None, indActive=None, **kwargs): Utils.setKwargs(self, **kwargs) assert isinstance(mesh, Mesh.BaseMesh), "mesh must be a SimPEG.Mesh object." if indActive is not None and indActive.dtype != 'bool': @@ -314,11 +314,19 @@ class BaseRegularization(object): if indActive is not None and mapping is None: mapping = Maps.IdentityMap(nP=indActive.nonzero()[0].size) + if mesh is None and nP is None: + raise Exception, 'either Mesh or number of parameters must be provided to the BaseRegularization' + self.regmesh = RegularizationMesh(mesh,indActive) - self.mapping = mapping or self.mapPair(mesh) - self.mapping._assertMatchesPair(self.mapPair) self.indActive = indActive + if mesh is not None and nP is None: + nP = self.regmesh.nC + self.nP = nP + + self.mapping = mapping or self.mapPair(nP=self.nP) + self.mapping._assertMatchesPair(self.mapPair) + @property def parent(self): """This is the parent of the regularization.""" @@ -346,7 +354,7 @@ class BaseRegularization(object): @property def W(self): """Full regularization weighting matrix W.""" - return sp.identity(self.regmesh.nC) + return sp.identity(self.nP) @Utils.timeIt def eval(self, m):