mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-28 00:45:10 +08:00
140 lines
4.5 KiB
Python
140 lines
4.5 KiB
Python
import Utils, Survey, Problem, numpy as np, scipy.sparse as sp, gc
|
|
from Utils.SolverUtils import *
|
|
import DataMisfit
|
|
import Regularization
|
|
|
|
|
|
class BaseInvProblem(object):
|
|
"""BaseInvProblem(dmisfit, reg, opt)"""
|
|
|
|
__metaclass__ = Utils.SimPEGMetaClass
|
|
|
|
beta = 1.0 #: Trade-off parameter
|
|
|
|
debug = False #: Print debugging information
|
|
counter = None #: Set this to a SimPEG.Utils.Counter() if you want to count things
|
|
|
|
dmisfit = None #: DataMisfit
|
|
reg = None #: Regularization
|
|
opt = None #: Optimization program
|
|
|
|
deleteTheseOnModelUpdate = [] # List of strings, e.g. ['_MeSigma', '_MeSigmaI']
|
|
|
|
@property
|
|
def curModel(self):
|
|
"""
|
|
Sets the current model, and removes dependent properties
|
|
"""
|
|
return getattr(self, '_curModel', None)
|
|
@curModel.setter
|
|
def curModel(self, value):
|
|
if value is self.curModel:
|
|
return # it is the same!
|
|
self._curModel = value
|
|
for prop in self.deleteTheseOnModelUpdate:
|
|
if hasattr(self, prop):
|
|
delattr(self, prop)
|
|
|
|
def __init__(self, dmisfit, reg, opt, **kwargs):
|
|
Utils.setKwargs(self, **kwargs)
|
|
assert isinstance(dmisfit, DataMisfit.BaseDataMisfit), 'dmisfit must be a DataMisfit class.'
|
|
assert isinstance(reg, Regularization.BaseRegularization), 'reg must be a Regularization class.'
|
|
self.dmisfit = dmisfit
|
|
self.reg = reg
|
|
self.opt = opt
|
|
self.prob, self.survey = dmisfit.prob, dmisfit.survey
|
|
#TODO: Remove: (and make iteration printers better!)
|
|
self.opt.parent = self
|
|
|
|
@Utils.callHooks('startup')
|
|
def startup(self, m0):
|
|
"""startup(m0)
|
|
|
|
Called when inversion is first starting.
|
|
"""
|
|
if self.debug: print 'Calling InvProblem.startup'
|
|
|
|
if self.reg.mref is None:
|
|
print 'SimPEG.InvProblem will set Regularization.mref to m0.'
|
|
self.reg.mref = m0
|
|
|
|
self.phi_d = np.nan
|
|
self.phi_m = np.nan
|
|
|
|
self.curModel = m0
|
|
|
|
print """SimPEG.InvProblem is setting bfgsH0 to the inverse of the eval2Deriv.
|
|
***Done using same solver as the problem***"""
|
|
self.opt.bfgsH0 = self.prob.Solver(self.reg.eval2Deriv(self.curModel))
|
|
|
|
@property
|
|
def warmstart(self):
|
|
return getattr(self, '_warmstart', [])
|
|
@warmstart.setter
|
|
def warmstart(self, value):
|
|
assert type(value) is list, 'warmstart must be a list.'
|
|
for v in value:
|
|
assert type(v) is tuple, 'warmstart must be a list of tuples (m, u).'
|
|
assert len(v) == 2, 'warmstart must be a list of tuples (m, u). YOURS IS NOT LENGTH 2!'
|
|
assert isinstance(v[0], np.ndarray), 'first warmstart value must be a model.'
|
|
self._warmstart = value
|
|
|
|
def getFields(self, m, store=False, deleteWarmstart=True):
|
|
u = None
|
|
|
|
for mtest, u_ofmtest in self.warmstart:
|
|
if m is mtest:
|
|
u = u_ofmtest
|
|
if self.debug: print 'InvProb is Warm Starting!'
|
|
break
|
|
|
|
if u is None:
|
|
u = self.prob.fields(m)
|
|
|
|
if deleteWarmstart:
|
|
self.warmstart = []
|
|
if store:
|
|
self.warmstart += [(m,u)]
|
|
|
|
return u
|
|
|
|
@Utils.timeIt
|
|
def evalFunction(self, m, return_g=True, return_H=True):
|
|
"""evalFunction(m, return_g=True, return_H=True)
|
|
"""
|
|
|
|
self.curModel = m
|
|
gc.collect()
|
|
|
|
# Store fields if doing a line-search
|
|
u = self.getFields(m, store=(return_g==False and return_H==False))
|
|
|
|
phi_d = self.dmisfit.eval(m, u=u)
|
|
phi_m = self.reg.eval(m)
|
|
|
|
self.dpred = self.survey.dpred(m, u=u) # This is a cheap matrix vector calculation.
|
|
|
|
self.phi_d, self.phi_d_last = phi_d, self.phi_d
|
|
self.phi_m, self.phi_m_last = phi_m, self.phi_m
|
|
|
|
f = phi_d + self.beta * phi_m
|
|
|
|
out = (f,)
|
|
if return_g:
|
|
phi_dDeriv = self.dmisfit.evalDeriv(m, u=u)
|
|
phi_mDeriv = self.reg.evalDeriv(m)
|
|
|
|
g = phi_dDeriv + self.beta * phi_mDeriv
|
|
out += (g,)
|
|
|
|
if return_H:
|
|
def H_fun(v):
|
|
phi_d2Deriv = self.dmisfit.eval2Deriv(m, v, u=u)
|
|
phi_m2Deriv = self.reg.eval2Deriv(m, v=v)
|
|
|
|
return phi_d2Deriv + self.beta * phi_m2Deriv
|
|
|
|
H = sp.linalg.LinearOperator( (m.size, m.size), H_fun, dtype=m.dtype )
|
|
out += (H,)
|
|
return out if len(out) > 1 else out[0]
|