mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-30 11:26:48 +08:00
Regularization is on the Model, so it needs to take that as input.
Testing is completed on every SimPEG regularization object by default.
This commit is contained in:
@@ -50,6 +50,11 @@ class BaseModel(object):
|
||||
"""
|
||||
return sp.identity(m.size)
|
||||
|
||||
@property
|
||||
def nP(self):
|
||||
"""Number of parameters in the model."""
|
||||
return self.mesh.nC
|
||||
|
||||
def example(self, modelType=None):
|
||||
return np.random.rand(self.mesh.nC)
|
||||
|
||||
|
||||
+90
-63
@@ -1,7 +1,93 @@
|
||||
from SimPEG import Utils, np, sp
|
||||
from SimPEG import Utils, Model, np, sp
|
||||
|
||||
class BaseRegularization(object):
|
||||
"""**Regularization**
|
||||
"""
|
||||
**Base Regularization Class**
|
||||
|
||||
This is used to regularize the model space::
|
||||
|
||||
reg = Regularization(mesh, model)
|
||||
|
||||
"""
|
||||
|
||||
__metaclass__ = Utils.Save.Savable
|
||||
|
||||
modelPair = Model.BaseModel #: Some regularizations only work on specific models
|
||||
|
||||
mesh = None #: A SimPEG.Mesh instance.
|
||||
model = None #: A SimPEG.Model instance.
|
||||
|
||||
counter = None
|
||||
|
||||
def __init__(self, mesh, model, **kwargs):
|
||||
Utils.setKwargs(self, **kwargs)
|
||||
assert isinstance(model, self.modelPair), "Incorrect model for this regularization"
|
||||
self.mesh = mesh
|
||||
self.model = model
|
||||
|
||||
@property
|
||||
def mref(self):
|
||||
if getattr(self, '_mref', None) is None:
|
||||
return np.zeros(self.model.nP);
|
||||
return self._mref
|
||||
@mref.setter
|
||||
def mref(self, value):
|
||||
self._mref = value
|
||||
|
||||
|
||||
@property
|
||||
def W(self):
|
||||
"""Full regularization weighting matrix W."""
|
||||
return sp.identity(self.model.nP)
|
||||
|
||||
|
||||
@Utils.timeIt
|
||||
def modelObj(self, m):
|
||||
r = self.W * (m - self.mref)
|
||||
return 0.5*r.dot(r)
|
||||
|
||||
@Utils.timeIt
|
||||
def modelObjDeriv(self, m):
|
||||
"""
|
||||
|
||||
The regularization is:
|
||||
|
||||
.. math::
|
||||
|
||||
R(m) = \\frac{1}{2}\mathbf{(m-m_\\text{ref})^\\top W^\\top W(m-m_\\text{ref})}
|
||||
|
||||
So the derivative is straight forward:
|
||||
|
||||
.. math::
|
||||
|
||||
R(m) = \mathbf{W^\\top W (m-m_\\text{ref})}
|
||||
|
||||
"""
|
||||
return self.W.T * ( self.W * (m - self.mref) )
|
||||
|
||||
@Utils.timeIt
|
||||
def modelObj2Deriv(self):
|
||||
"""
|
||||
|
||||
The regularization is:
|
||||
|
||||
.. math::
|
||||
|
||||
R(m) = \\frac{1}{2}\mathbf{(m-m_\\text{ref})^\\top W^\\top W(m-m_\\text{ref})}
|
||||
|
||||
So the second derivative is straight forward:
|
||||
|
||||
.. math::
|
||||
|
||||
R(m) = \mathbf{W^\\top W}
|
||||
|
||||
"""
|
||||
return self.W.T * self.W
|
||||
|
||||
|
||||
|
||||
class Tikhonov(BaseRegularization):
|
||||
"""**Tikhonov Regularization**
|
||||
|
||||
Here we will define regularization of a model, m, in general however, this should be thought of as (m-m_ref) but otherwise it is exactly the same:
|
||||
|
||||
@@ -83,8 +169,6 @@ class BaseRegularization(object):
|
||||
|
||||
"""
|
||||
|
||||
__metaclass__ = Utils.Save.Savable
|
||||
|
||||
alpha_s = Utils.dependentProperty('_alpha_s', 1e-6, ['_W', '_Ws'], "Smallness weight")
|
||||
alpha_x = Utils.dependentProperty('_alpha_x', 1.0, ['_W', '_Wx'], "Weight for the first derivative in the x direction")
|
||||
alpha_y = Utils.dependentProperty('_alpha_y', 1.0, ['_W', '_Wy'], "Weight for the first derivative in the y direction")
|
||||
@@ -93,20 +177,8 @@ class BaseRegularization(object):
|
||||
alpha_yy = Utils.dependentProperty('_alpha_yy', 0.0, ['_W', '_Wyy'], "Weight for the second derivative in the y direction")
|
||||
alpha_zz = Utils.dependentProperty('_alpha_zz', 0.0, ['_W', '_Wzz'], "Weight for the second derivative in the z direction")
|
||||
|
||||
counter = None
|
||||
|
||||
def __init__(self, mesh, **kwargs):
|
||||
Utils.setKwargs(self, **kwargs)
|
||||
self.mesh = mesh
|
||||
|
||||
@property
|
||||
def mref(self):
|
||||
if getattr(self, '_mref', None) is None:
|
||||
return np.zeros(self.mesh.nC);
|
||||
return self._mref
|
||||
@mref.setter
|
||||
def mref(self, value):
|
||||
self._mref = value
|
||||
def __init__(self, mesh, model, **kwargs):
|
||||
BaseRegularization.__init__(self, mesh, model, **kwargs)
|
||||
|
||||
@property
|
||||
def Ws(self):
|
||||
@@ -160,7 +232,6 @@ class BaseRegularization(object):
|
||||
self._Wzz = Utils.sdiag((self.mesh.vol*self.alpha_zz)**0.5)*self.mesh.faceDivz*self.mesh.cellGradz
|
||||
return self._Wzz
|
||||
|
||||
|
||||
@property
|
||||
def W(self):
|
||||
"""Full regularization matrix W"""
|
||||
@@ -173,47 +244,3 @@ class BaseRegularization(object):
|
||||
self._W = sp.vstack(wlist)
|
||||
return self._W
|
||||
|
||||
|
||||
@Utils.timeIt
|
||||
def modelObj(self, m):
|
||||
r = self.W * (m - self.mref)
|
||||
return 0.5*r.dot(r)
|
||||
|
||||
@Utils.timeIt
|
||||
def modelObjDeriv(self, m):
|
||||
"""
|
||||
|
||||
The regularization is:
|
||||
|
||||
.. math::
|
||||
|
||||
R(m) = \\frac{1}{2}\mathbf{(m-m_\\text{ref})^\\top W^\\top W(m-m_\\text{ref})}
|
||||
|
||||
So the derivative is straight forward:
|
||||
|
||||
.. math::
|
||||
|
||||
R(m) = \mathbf{W^\\top W (m-m_\\text{ref})}
|
||||
|
||||
"""
|
||||
return self.W.T * ( self.W * (m - self.mref) )
|
||||
|
||||
@Utils.timeIt
|
||||
def modelObj2Deriv(self):
|
||||
"""
|
||||
|
||||
The regularization is:
|
||||
|
||||
.. math::
|
||||
|
||||
R(m) = \\frac{1}{2}\mathbf{(m-m_\\text{ref})^\\top W^\\top W(m-m_\\text{ref})}
|
||||
|
||||
So the second derivative is straight forward:
|
||||
|
||||
.. math::
|
||||
|
||||
R(m) = \mathbf{W^\\top W}
|
||||
|
||||
"""
|
||||
return self.W.T * self.W
|
||||
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
import numpy as np
|
||||
import unittest
|
||||
from SimPEG import *
|
||||
from TestUtils import checkDerivative
|
||||
from scipy.sparse.linalg import dsolve
|
||||
|
||||
|
||||
class ProblemTests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
||||
a = np.array([1, 1, 1])
|
||||
b = np.array([1, 2])
|
||||
c = np.array([1, 4])
|
||||
self.mesh2 = Mesh.TensorMesh([a, b], np.array([3, 5]))
|
||||
self.p2 = Problem.BaseProblem(self.mesh2, None)
|
||||
self.reg = Regularization.BaseRegularization(self.mesh2)
|
||||
|
||||
def test_regularization(self):
|
||||
derChk = lambda m: [self.reg.modelObj(m), self.reg.modelObjDeriv(m)]
|
||||
mSynth = np.random.randn(self.mesh2.nC)
|
||||
checkDerivative(derChk, mSynth, plotIt=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,30 @@
|
||||
import numpy as np
|
||||
import unittest
|
||||
from SimPEG import *
|
||||
from TestUtils import checkDerivative
|
||||
from scipy.sparse.linalg import dsolve
|
||||
import inspect
|
||||
|
||||
|
||||
class RegularizationTests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.mesh2 = Mesh.TensorMesh([3, 2])
|
||||
|
||||
def test_regularization(self):
|
||||
for R in dir(Regularization):
|
||||
r = getattr(Regularization, R)
|
||||
if not inspect.isclass(r): continue
|
||||
if not issubclass(r, Regularization.BaseRegularization):
|
||||
continue
|
||||
# if 'Regularization' not in R: continue
|
||||
print 'Check:', R
|
||||
model = r.modelPair(self.mesh2)
|
||||
reg = r(self.mesh2, model)
|
||||
m = model.example()
|
||||
passed = checkDerivative(lambda m : [reg.modelObj(m), reg.modelObjDeriv(m)], m, plotIt=False)
|
||||
self.assertTrue(passed)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user