mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-03 04:20:35 +08:00
Simplifications to Regularization code.
This commit is contained in:
@@ -85,13 +85,13 @@ class Regularization(object):
|
||||
|
||||
__metaclass__ = utils.Save.Savable
|
||||
|
||||
alpha_s = 1e-6 #: Smallness weight
|
||||
alpha_x = 1.0 #: Weight for the first derivative in the x direction
|
||||
alpha_y = 1.0 #: Weight for the first derivative in the y direction
|
||||
alpha_z = 1.0 #: Weight for the first derivative in the z direction
|
||||
alpha_xx = 0.0 #: Weight for the second derivative in the x direction
|
||||
alpha_yy = 0.0 #: Weight for the second derivative in the y direction
|
||||
alpha_zz = 0.0 #: Weight for the second derivative in the z direction
|
||||
alpha_s = utils.dependentProperty('_alpha_s', 1e-6, ['_W', '_Ws'], "Smallness weight")
|
||||
alpha_x = utils.dependentProperty('_alpha_x', 1.0, ['_W', '_Wx'], "Weight for the first derivative in the x direction")
|
||||
alpha_y = utils.dependentProperty('_alpha_y', 1.0, ['_W', '_Wy'], "Weight for the first derivative in the y direction")
|
||||
alpha_z = utils.dependentProperty('_alpha_z', 1.0, ['_W', '_Wz'], "Weight for the first derivative in the z direction")
|
||||
alpha_xx = utils.dependentProperty('_alpha_xx', 0.0, ['_W', '_Wxx'], "Weight for the second derivative in the x direction")
|
||||
alpha_yy = utils.dependentProperty('_alpha_yy', 0.0, ['_W', '_Wyy'], "Weight for the second derivative in the y direction")
|
||||
alpha_zz = utils.dependentProperty('_alpha_zz', 0.0, ['_W', '_Wzz'], "Weight for the second derivative in the z direction")
|
||||
|
||||
counter = None
|
||||
|
||||
@@ -110,124 +110,110 @@ class Regularization(object):
|
||||
|
||||
@property
|
||||
def Ws(self):
|
||||
"""Regularization matrix Ws"""
|
||||
if getattr(self,'_Ws', None) is None:
|
||||
self._Ws = utils.sdiag(self.mesh.vol**0.5)
|
||||
self._Ws = utils.sdiag((self.mesh.vol*self.alpha_s)**0.5)
|
||||
return self._Ws
|
||||
|
||||
@property
|
||||
def Wx(self):
|
||||
"""Regularization matrix Wx"""
|
||||
if getattr(self, '_Wx', None) is None:
|
||||
Ave_x_vol = self.mesh.aveF2CC[:,:self.mesh.nFv[0]].T*self.mesh.vol
|
||||
self._Wx = utils.sdiag(Ave_x_vol**0.5)*self.mesh.cellGradx
|
||||
self._Wx = utils.sdiag((Ave_x_vol*self.alpha_x)**0.5)*self.mesh.cellGradx
|
||||
return self._Wx
|
||||
|
||||
@property
|
||||
def Wy(self):
|
||||
"""Regularization matrix Wy"""
|
||||
if getattr(self, '_Wy', None) is None:
|
||||
Ave_y_vol = self.mesh.aveF2CC[:,self.mesh.nFv[0]:np.sum(self.mesh.nFv[:2])].T*self.mesh.vol
|
||||
self._Wy = utils.sdiag(Ave_y_vol**0.5)*self.mesh.cellGrady
|
||||
self._Wy = utils.sdiag((Ave_y_vol*self.alpha_y)**0.5)*self.mesh.cellGrady
|
||||
return self._Wy
|
||||
|
||||
@property
|
||||
def Wz(self):
|
||||
"""Regularization matrix Wz"""
|
||||
if getattr(self, '_Wz', None) is None:
|
||||
Ave_z_vol = self.mesh.aveF2CC[:,np.sum(self.mesh.nFv[:2]):].T*self.mesh.vol
|
||||
self._Wz = utils.sdiag(Ave_z_vol**0.5)*self.mesh.cellGradz
|
||||
self._Wz = utils.sdiag((Ave_z_vol*self.alpha_z)**0.5)*self.mesh.cellGradz
|
||||
return self._Wz
|
||||
|
||||
@property
|
||||
def Wxx(self):
|
||||
"""Regularization matrix Wxx"""
|
||||
if getattr(self, '_Wxx', None) is None:
|
||||
self._Wxx = utils.sdiag(self.mesh.vol**0.5)*self.mesh.faceDivx*self.mesh.cellGradx
|
||||
self._Wxx = utils.sdiag((self.mesh.vol*self.alpha_xx)**0.5)*self.mesh.faceDivx*self.mesh.cellGradx
|
||||
return self._Wxx
|
||||
|
||||
@property
|
||||
def Wyy(self):
|
||||
"""Regularization matrix Wyy"""
|
||||
if getattr(self, '_Wyy', None) is None:
|
||||
self._Wyy = utils.sdiag(self.mesh.vol**0.5)*self.mesh.faceDivy*self.mesh.cellGrady
|
||||
self._Wyy = utils.sdiag((self.mesh.vol*self.alpha_yy)**0.5)*self.mesh.faceDivy*self.mesh.cellGrady
|
||||
return self._Wyy
|
||||
|
||||
@property
|
||||
def Wzz(self):
|
||||
"""Regularization matrix Wzz"""
|
||||
if getattr(self, '_Wzz', None) is None:
|
||||
self._Wzz = utils.sdiag(self.mesh.vol**0.5)*self.mesh.faceDivz*self.mesh.cellGradz
|
||||
self._Wzz = utils.sdiag((self.mesh.vol*self.alpha_zz)**0.5)*self.mesh.faceDivz*self.mesh.cellGradz
|
||||
return self._Wzz
|
||||
|
||||
|
||||
def pnorm(self, r):
|
||||
return 0.5*r.dot(r)
|
||||
@property
|
||||
def W(self):
|
||||
"""Full regularization matrix W"""
|
||||
if getattr(self, '_W', None) is None:
|
||||
wlist = (self.Ws, self.Wx, self.Wxx)
|
||||
if self.mesh.dim > 1:
|
||||
wlist += (self.Wy, self.Wyy)
|
||||
if self.mesh.dim > 2:
|
||||
wlist += (self.Wz, self.Wzz)
|
||||
self._W = sp.vstack(wlist)
|
||||
return self._W
|
||||
|
||||
|
||||
@utils.timeIt
|
||||
def modelObj(self, m):
|
||||
mresid = m - self.mref
|
||||
|
||||
mobj = self.alpha_s * self.pnorm( self.Ws * mresid )
|
||||
|
||||
mobj += self.alpha_x * self.pnorm( self.Wx * mresid )
|
||||
mobj += self.alpha_xx * self.pnorm( self.Wxx * mresid )
|
||||
|
||||
if self.mesh.dim > 1:
|
||||
mobj += self.alpha_y * self.pnorm( self.Wy * mresid )
|
||||
mobj += self.alpha_yy * self.pnorm( self.Wyy * mresid )
|
||||
if self.mesh.dim > 2:
|
||||
mobj += self.alpha_z * self.pnorm( self.Wz * mresid )
|
||||
mobj += self.alpha_zz * self.pnorm( self.Wzz * mresid )
|
||||
|
||||
return mobj
|
||||
r = self.W * (m - self.mref)
|
||||
return 0.5*r.dot(r)
|
||||
|
||||
@utils.timeIt
|
||||
def modelObjDeriv(self, m):
|
||||
"""
|
||||
|
||||
In 1D:
|
||||
The regularization is:
|
||||
|
||||
.. math::
|
||||
|
||||
m_{\\text{obj}} = {1 \over 2}\\alpha_s \left\| W_s (m- m_{\\text{ref}})\\right\|^2_2
|
||||
+ {1 \over 2}\\alpha_x \left\| W_x (m- m_{\\text{ref}})\\right\|^2_2
|
||||
R(m) = \\frac{1}{2}\mathbf{(m-m_\\text{ref})^\\top W^\\top W(m-m_\\text{ref})}
|
||||
|
||||
\\frac{ \partial m_{\\text{obj}} }{\partial m} =
|
||||
\\alpha_s W_s^{\\top} W_s (m - m_{\\text{ref}}) +
|
||||
\\alpha_x W_x^{\\top} W_x (m - m_{\\text{ref}})
|
||||
So the derivative is straight forward:
|
||||
|
||||
.. math::
|
||||
|
||||
\\frac{ \partial^2 m_{\\text{obj}} }{\partial m^2} =
|
||||
\\alpha_s W_s^{\\top} W_s +
|
||||
\\alpha_x W_x^{\\top} W_x
|
||||
R(m) = \mathbf{W^\\top W (m-m_\\text{ref})}
|
||||
|
||||
"""
|
||||
|
||||
mresid = m - self.mref
|
||||
|
||||
mobjDeriv = self.alpha_s * self.Ws.T * ( self.Ws * mresid)
|
||||
|
||||
mobjDeriv = mobjDeriv + self.alpha_x * self.Wx.T * ( self.Wx * mresid)
|
||||
mobjDeriv = mobjDeriv + self.alpha_xx * self.Wxx.T * ( self.Wxx * mresid)
|
||||
|
||||
if self.mesh.dim > 1:
|
||||
mobjDeriv = mobjDeriv + self.alpha_y * self.Wy.T * ( self.Wy * mresid)
|
||||
mobjDeriv = mobjDeriv + self.alpha_yy * self.Wyy.T * ( self.Wyy * mresid)
|
||||
if self.mesh.dim > 2:
|
||||
mobjDeriv = mobjDeriv + self.alpha_z * self.Wz.T * ( self.Wz * mresid)
|
||||
mobjDeriv = mobjDeriv + self.alpha_zz * self.Wzz.T * ( self.Wzz * mresid)
|
||||
|
||||
return mobjDeriv
|
||||
|
||||
return self.W.T * ( self.W * (m - self.mref) )
|
||||
|
||||
@utils.timeIt
|
||||
def modelObj2Deriv(self):
|
||||
"""
|
||||
|
||||
mobj2Deriv = self.alpha_s * self.Ws.T * self.Ws
|
||||
The regularization is:
|
||||
|
||||
mobj2Deriv = mobj2Deriv + self.alpha_x * self.Wx.T * self.Wx
|
||||
mobj2Deriv = mobj2Deriv + self.alpha_xx * self.Wxx.T * self.Wxx
|
||||
.. math::
|
||||
|
||||
if self.mesh.dim > 1:
|
||||
mobj2Deriv = mobj2Deriv + self.alpha_y * self.Wy.T * self.Wy
|
||||
mobj2Deriv = mobj2Deriv + self.alpha_yy * self.Wyy.T * self.Wyy
|
||||
if self.mesh.dim > 2:
|
||||
mobj2Deriv = mobj2Deriv + self.alpha_z * self.Wz.T * self.Wz
|
||||
mobj2Deriv = mobj2Deriv + self.alpha_zz * self.Wzz.T * self.Wzz
|
||||
R(m) = \\frac{1}{2}\mathbf{(m-m_\\text{ref})^\\top W^\\top W(m-m_\\text{ref})}
|
||||
|
||||
return mobj2Deriv
|
||||
So the second derivative is straight forward:
|
||||
|
||||
.. math::
|
||||
|
||||
R(m) = \mathbf{W^\\top W}
|
||||
|
||||
"""
|
||||
return self.W.T * self.W
|
||||
|
||||
|
||||
@@ -124,6 +124,15 @@ def callHooks(match):
|
||||
return wrapper
|
||||
return callHooksWrap
|
||||
|
||||
def dependentProperty(name, value, children, doc):
|
||||
def fget(self): return getattr(self,name,value)
|
||||
def fset(self, val):
|
||||
for child in children:
|
||||
if hasattr(self, child):
|
||||
delattr(self, child)
|
||||
setattr(self, name, val)
|
||||
return property(fget=fget, fset=fset, doc=doc)
|
||||
|
||||
|
||||
class Counter(object):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user