mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-02 07:39:13 +08:00
TimeSteppingInversion and Estimate Initial Beta based on eigenvalues comparison.
This commit is contained in:
+67
-13
@@ -4,16 +4,19 @@ import SimPEG
|
||||
from SimPEG.utils import sdiag, mkvc, setKwargs, checkStoppers, printStoppers, count, timeIt, callHooks
|
||||
from Optimize import Remember
|
||||
from BetaSchedule import Cooling
|
||||
from SimPEG.inverse import IterationPrinters, StoppingCriteria
|
||||
|
||||
class BaseInversion(object):
|
||||
"""docstring for BaseInversion"""
|
||||
|
||||
maxIter = 1
|
||||
maxIter = 1 #: Maximum number of iterations
|
||||
name = 'BaseInversion'
|
||||
debug = False
|
||||
beta0 = 1e4
|
||||
|
||||
counter = None
|
||||
debug = False #: Print debugging information
|
||||
|
||||
comment = '' #: Used by some functions to indicate what is going on in the algorithm
|
||||
counter = None #: Set this to a SimPEG.utils.Counter() if you want to count things
|
||||
|
||||
|
||||
def __init__(self, prob, reg, opt, **kwargs):
|
||||
setKwargs(self, **kwargs)
|
||||
@@ -22,14 +25,13 @@ class BaseInversion(object):
|
||||
self.opt = opt
|
||||
self.opt.parent = self
|
||||
|
||||
self.stoppers = [SimPEG.inverse.StoppingCriteria.iteration, SimPEG.inverse.StoppingCriteria.phi_d_target_Inversion]
|
||||
self.stoppers = [StoppingCriteria.iteration]
|
||||
|
||||
# Check if we have inserted printers into the optimization
|
||||
if not np.any([p is SimPEG.inverse.IterationPrinters.phi_d for p in self.opt.printers]):
|
||||
self.opt.printers.insert(1,SimPEG.inverse.IterationPrinters.beta)
|
||||
self.opt.printers.insert(2,SimPEG.inverse.IterationPrinters.phi_d)
|
||||
self.opt.printers.insert(3,SimPEG.inverse.IterationPrinters.phi_m)
|
||||
self.opt.stoppers.append(SimPEG.inverse.StoppingCriteria.phi_d_target_Minimize)
|
||||
if IterationPrinters.phi_d not in self.opt.printers:
|
||||
self.opt.printers.insert(1,IterationPrinters.beta)
|
||||
self.opt.printers.insert(2,IterationPrinters.phi_d)
|
||||
self.opt.printers.insert(3,IterationPrinters.phi_m)
|
||||
|
||||
if not hasattr(opt, '_bfgsH0') and hasattr(opt, 'bfgsH0'): # Check if it has been set by the user and the default is not being used.
|
||||
print 'Setting bfgsH0 to the inverse of the modelObj2Deriv. Done using direct methods.'
|
||||
@@ -128,9 +130,29 @@ class BaseInversion(object):
|
||||
self.phi_m_last = self.phi_m
|
||||
self._iter += 1
|
||||
|
||||
@property
|
||||
def beta0(self):
|
||||
if getattr(self,'_beta0',None) is None:
|
||||
self._beta0 = self.estimateBeta0()
|
||||
return self._beta0
|
||||
@beta0.setter
|
||||
def beta0(self, value):
|
||||
self._beta0 = value
|
||||
|
||||
def getBeta(self):
|
||||
return self.beta0
|
||||
|
||||
def estimateBeta0(self):
|
||||
"""
|
||||
|
||||
|
||||
"""
|
||||
u = self.prob.field(self.m)
|
||||
v = np.random.rand(*self.m.shape)
|
||||
t = v.dot(self.dataObj2Deriv(self.m,v,u=u))
|
||||
b = v.dot(self.reg.modelObj2Deriv()*v)
|
||||
return 0.1*(t/b)
|
||||
|
||||
def stoppingCriteria(self):
|
||||
if self.debug: print 'checking stoppingCriteria'
|
||||
return checkStoppers(self, self.stoppers)
|
||||
@@ -170,7 +192,7 @@ class BaseInversion(object):
|
||||
|
||||
return phi_d2Deriv + self._beta * phi_m2Deriv
|
||||
|
||||
operator = sp.linalg.LinearOperator( (m.size, m.size), H_fun, dtype=float )
|
||||
operator = sp.linalg.LinearOperator( (m.size, m.size), H_fun, dtype=m.dtype )
|
||||
out += (operator,)
|
||||
return out if len(out) > 1 else out[0]
|
||||
|
||||
@@ -239,8 +261,10 @@ class BaseInversion(object):
|
||||
|
||||
@timeIt
|
||||
def dataObj2Deriv(self, m, v, u=None):
|
||||
"""
|
||||
"""dataObj2Deriv(m, v, u=None)
|
||||
|
||||
:param numpy.array m: geophysical model
|
||||
:param numpy.array v: vector to multiply
|
||||
:param numpy.array u: fields
|
||||
:rtype: numpy.array
|
||||
:return: data misfit derivative
|
||||
@@ -277,7 +301,7 @@ class BaseInversion(object):
|
||||
R = self.Wd*self.prob.dataResidual(m, u=u)
|
||||
|
||||
# TODO: abstract to different norms a little cleaner.
|
||||
# \/ it goes here. in l2 it is the identity.
|
||||
# \/ it goes here. in l2 it is the identity.
|
||||
dmisfit = self.prob.Jt_approx(m, self.Wd * self.Wd * self.prob.J_approx(m, v, u=u), u=u)
|
||||
|
||||
return dmisfit
|
||||
@@ -289,3 +313,33 @@ class Inversion(Cooling, Remember, BaseInversion):
|
||||
|
||||
def __init__(self, prob, reg, opt, **kwargs):
|
||||
BaseInversion.__init__(self, prob, reg, opt, **kwargs)
|
||||
|
||||
self.stoppers.append(StoppingCriteria.phi_d_target_Inversion)
|
||||
|
||||
if StoppingCriteria.phi_d_target_Minimize not in self.opt.stoppers:
|
||||
self.opt.stoppers.append(StoppingCriteria.phi_d_target_Minimize)
|
||||
|
||||
class TimeSteppingInversion(Remember, BaseInversion):
|
||||
"""
|
||||
A slightly different view on regularization parameters,
|
||||
let Beta be viewed as 1/dt, and timestep by updating the
|
||||
reference model every optimization iteration.
|
||||
"""
|
||||
maxIter = 1
|
||||
name = "Time-Stepping SimPEG Inversion"
|
||||
|
||||
def __init__(self, prob, reg, opt, **kwargs):
|
||||
BaseInversion.__init__(self, prob, reg, opt, **kwargs)
|
||||
|
||||
self.stoppers.append(StoppingCriteria.phi_d_target_Inversion)
|
||||
|
||||
if StoppingCriteria.phi_d_target_Minimize not in self.opt.stoppers:
|
||||
self.opt.stoppers.append(StoppingCriteria.phi_d_target_Minimize)
|
||||
|
||||
def _startup_TimeSteppingInversion(self, m0):
|
||||
|
||||
def _doEndIteration_updateMref(self, xt):
|
||||
if self.debug: 'Updating the reference model.'
|
||||
self.parent.reg.mref = self.xc
|
||||
|
||||
self.opt.hook(_doEndIteration_updateMref, overwrite=True)
|
||||
|
||||
@@ -474,8 +474,8 @@ class Remember(object):
|
||||
class ProjectedGradient(Minimize, Remember):
|
||||
name = 'Projected Gradient'
|
||||
|
||||
maxIterCG = 10
|
||||
tolCG = 1e-3
|
||||
maxIterCG = 5
|
||||
tolCG = 1e-1
|
||||
|
||||
lower = -np.inf
|
||||
upper = np.inf
|
||||
@@ -716,8 +716,8 @@ class InexactGaussNewton(BFGS, Minimize, Remember):
|
||||
|
||||
name = 'Inexact Gauss Newton'
|
||||
|
||||
maxIterCG = 10
|
||||
tolCG = 1e-3
|
||||
maxIterCG = 5
|
||||
tolCG = 1e-1
|
||||
|
||||
@property
|
||||
def approxHinv(self):
|
||||
|
||||
@@ -26,9 +26,12 @@ def hook(obj, method, name=None, overwrite=False, silent=False):
|
||||
if name is None: name = method.__name__
|
||||
if not hasattr(obj,name) or overwrite:
|
||||
setattr(obj, name, types.MethodType( method, obj ))
|
||||
elif not silent:
|
||||
if getattr(obj,'debug',False):
|
||||
print 'Method '+name+' was added to class.'
|
||||
elif not silent or getattr(obj,'debug',False):
|
||||
print 'Method '+name+' was not overwritten.'
|
||||
|
||||
|
||||
def setKwargs(obj, **kwargs):
|
||||
"""Sets key word arguments (kwargs) that are present in the object, throw an error if they don't exist."""
|
||||
for attr in kwargs:
|
||||
@@ -97,7 +100,7 @@ class Counter(object):
|
||||
|
||||
If you want to use this, import *count* or *timeIt* and use them as decorators on class methods.
|
||||
|
||||
.. ::
|
||||
::
|
||||
|
||||
class MyClass(object):
|
||||
def __init__(self, url):
|
||||
|
||||
Reference in New Issue
Block a user