diff --git a/SimPEG/Model.py b/SimPEG/Model.py index 6eb3c297..8482edde 100644 --- a/SimPEG/Model.py +++ b/SimPEG/Model.py @@ -50,6 +50,11 @@ class BaseModel(object): """ return sp.identity(m.size) + @property + def nP(self): + """Number of parameters in the model.""" + return self.mesh.nC + def example(self, modelType=None): return np.random.rand(self.mesh.nC) diff --git a/SimPEG/Regularization.py b/SimPEG/Regularization.py index e592d643..ea6cd753 100644 --- a/SimPEG/Regularization.py +++ b/SimPEG/Regularization.py @@ -1,7 +1,93 @@ -from SimPEG import Utils, np, sp +from SimPEG import Utils, Model, np, sp class BaseRegularization(object): - """**Regularization** + """ + **Base Regularization Class** + + This is used to regularize the model space:: + + reg = Regularization(mesh, model) + + """ + + __metaclass__ = Utils.Save.Savable + + modelPair = Model.BaseModel #: Some regularizations only work on specific models + + mesh = None #: A SimPEG.Mesh instance. + model = None #: A SimPEG.Model instance. + + counter = None + + def __init__(self, mesh, model, **kwargs): + Utils.setKwargs(self, **kwargs) + assert isinstance(model, self.modelPair), "Incorrect model for this regularization" + self.mesh = mesh + self.model = model + + @property + def mref(self): + if getattr(self, '_mref', None) is None: + return np.zeros(self.model.nP); + return self._mref + @mref.setter + def mref(self, value): + self._mref = value + + + @property + def W(self): + """Full regularization weighting matrix W.""" + return sp.identity(self.model.nP) + + + @Utils.timeIt + def modelObj(self, m): + r = self.W * (m - self.mref) + return 0.5*r.dot(r) + + @Utils.timeIt + def modelObjDeriv(self, m): + """ + + The regularization is: + + .. math:: + + R(m) = \\frac{1}{2}\mathbf{(m-m_\\text{ref})^\\top W^\\top W(m-m_\\text{ref})} + + So the derivative is straight forward: + + .. math:: + + R(m) = \mathbf{W^\\top W (m-m_\\text{ref})} + + """ + return self.W.T * ( self.W * (m - self.mref) ) + + @Utils.timeIt + def modelObj2Deriv(self): + """ + + The regularization is: + + .. math:: + + R(m) = \\frac{1}{2}\mathbf{(m-m_\\text{ref})^\\top W^\\top W(m-m_\\text{ref})} + + So the second derivative is straight forward: + + .. math:: + + R(m) = \mathbf{W^\\top W} + + """ + return self.W.T * self.W + + + +class Tikhonov(BaseRegularization): + """**Tikhonov Regularization** Here we will define regularization of a model, m, in general however, this should be thought of as (m-m_ref) but otherwise it is exactly the same: @@ -83,8 +169,6 @@ class BaseRegularization(object): """ - __metaclass__ = Utils.Save.Savable - alpha_s = Utils.dependentProperty('_alpha_s', 1e-6, ['_W', '_Ws'], "Smallness weight") alpha_x = Utils.dependentProperty('_alpha_x', 1.0, ['_W', '_Wx'], "Weight for the first derivative in the x direction") alpha_y = Utils.dependentProperty('_alpha_y', 1.0, ['_W', '_Wy'], "Weight for the first derivative in the y direction") @@ -93,20 +177,8 @@ class BaseRegularization(object): alpha_yy = Utils.dependentProperty('_alpha_yy', 0.0, ['_W', '_Wyy'], "Weight for the second derivative in the y direction") alpha_zz = Utils.dependentProperty('_alpha_zz', 0.0, ['_W', '_Wzz'], "Weight for the second derivative in the z direction") - counter = None - - def __init__(self, mesh, **kwargs): - Utils.setKwargs(self, **kwargs) - self.mesh = mesh - - @property - def mref(self): - if getattr(self, '_mref', None) is None: - return np.zeros(self.mesh.nC); - return self._mref - @mref.setter - def mref(self, value): - self._mref = value + def __init__(self, mesh, model, **kwargs): + BaseRegularization.__init__(self, mesh, model, **kwargs) @property def Ws(self): @@ -160,7 +232,6 @@ class BaseRegularization(object): self._Wzz = Utils.sdiag((self.mesh.vol*self.alpha_zz)**0.5)*self.mesh.faceDivz*self.mesh.cellGradz return self._Wzz - @property def W(self): """Full regularization matrix W""" @@ -173,47 +244,3 @@ class BaseRegularization(object): self._W = sp.vstack(wlist) return self._W - - @Utils.timeIt - def modelObj(self, m): - r = self.W * (m - self.mref) - return 0.5*r.dot(r) - - @Utils.timeIt - def modelObjDeriv(self, m): - """ - - The regularization is: - - .. math:: - - R(m) = \\frac{1}{2}\mathbf{(m-m_\\text{ref})^\\top W^\\top W(m-m_\\text{ref})} - - So the derivative is straight forward: - - .. math:: - - R(m) = \mathbf{W^\\top W (m-m_\\text{ref})} - - """ - return self.W.T * ( self.W * (m - self.mref) ) - - @Utils.timeIt - def modelObj2Deriv(self): - """ - - The regularization is: - - .. math:: - - R(m) = \\frac{1}{2}\mathbf{(m-m_\\text{ref})^\\top W^\\top W(m-m_\\text{ref})} - - So the second derivative is straight forward: - - .. math:: - - R(m) = \mathbf{W^\\top W} - - """ - return self.W.T * self.W - diff --git a/SimPEG/Tests/test_problem.py b/SimPEG/Tests/test_problem.py deleted file mode 100644 index e1d30337..00000000 --- a/SimPEG/Tests/test_problem.py +++ /dev/null @@ -1,26 +0,0 @@ -import numpy as np -import unittest -from SimPEG import * -from TestUtils import checkDerivative -from scipy.sparse.linalg import dsolve - - -class ProblemTests(unittest.TestCase): - - def setUp(self): - - a = np.array([1, 1, 1]) - b = np.array([1, 2]) - c = np.array([1, 4]) - self.mesh2 = Mesh.TensorMesh([a, b], np.array([3, 5])) - self.p2 = Problem.BaseProblem(self.mesh2, None) - self.reg = Regularization.BaseRegularization(self.mesh2) - - def test_regularization(self): - derChk = lambda m: [self.reg.modelObj(m), self.reg.modelObjDeriv(m)] - mSynth = np.random.randn(self.mesh2.nC) - checkDerivative(derChk, mSynth, plotIt=False) - - -if __name__ == '__main__': - unittest.main() diff --git a/SimPEG/Tests/test_regularization.py b/SimPEG/Tests/test_regularization.py new file mode 100644 index 00000000..129948e1 --- /dev/null +++ b/SimPEG/Tests/test_regularization.py @@ -0,0 +1,30 @@ +import numpy as np +import unittest +from SimPEG import * +from TestUtils import checkDerivative +from scipy.sparse.linalg import dsolve +import inspect + + +class RegularizationTests(unittest.TestCase): + + def setUp(self): + self.mesh2 = Mesh.TensorMesh([3, 2]) + + def test_regularization(self): + for R in dir(Regularization): + r = getattr(Regularization, R) + if not inspect.isclass(r): continue + if not issubclass(r, Regularization.BaseRegularization): + continue + # if 'Regularization' not in R: continue + print 'Check:', R + model = r.modelPair(self.mesh2) + reg = r(self.mesh2, model) + m = model.example() + passed = checkDerivative(lambda m : [reg.modelObj(m), reg.modelObjDeriv(m)], m, plotIt=False) + self.assertTrue(passed) + + +if __name__ == '__main__': + unittest.main()