mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-02 09:21:32 +08:00
Parameter function for Beta instead of a scalar. Bug Fixes.
This commit is contained in:
+6
-65
@@ -8,7 +8,6 @@ class BaseInversion(object):
|
||||
|
||||
__metaclass__ = Utils.Save.Savable
|
||||
|
||||
maxIter = 1 #: Maximum number of iterations
|
||||
name = 'BaseInversion'
|
||||
|
||||
debug = False #: Print debugging information
|
||||
@@ -16,11 +15,11 @@ class BaseInversion(object):
|
||||
comment = '' #: Used by some functions to indicate what is going on in the algorithm
|
||||
counter = None #: Set this to a SimPEG.Utils.Counter() if you want to count things
|
||||
|
||||
def __init__(self, dataObj, opt, **kwargs):
|
||||
def __init__(self, objFunc, opt, **kwargs):
|
||||
Utils.setKwargs(self, **kwargs)
|
||||
|
||||
self.dataObj = dataObj
|
||||
self.dataObj.parent = self
|
||||
self.objFunc = objFunc
|
||||
self.objFunc.parent = self
|
||||
|
||||
self.opt = opt
|
||||
self.opt.parent = self
|
||||
@@ -36,7 +35,7 @@ class BaseInversion(object):
|
||||
if not hasattr(opt, '_bfgsH0') and hasattr(opt, 'bfgsH0'): # Check if it has been set by the user and the default is not being used.
|
||||
#TODO: I don't think that this if statement is working...
|
||||
print 'Setting bfgsH0 to the inverse of the modelObj2Deriv. Done using direct methods.'
|
||||
opt.bfgsH0 = SimPEG.Solver(reg.modelObj2Deriv())
|
||||
opt.bfgsH0 = SimPEG.Solver(objFunc.reg.modelObj2Deriv())
|
||||
|
||||
|
||||
@property
|
||||
@@ -63,75 +62,17 @@ class BaseInversion(object):
|
||||
Runs the inversion!
|
||||
|
||||
"""
|
||||
self.startup(m0)
|
||||
while True:
|
||||
self.doStartIteration()
|
||||
self.m = self.opt.minimize(self.evalFunction, self.m)
|
||||
self.doEndIteration()
|
||||
if self.stoppingCriteria(): break
|
||||
|
||||
self.printDone()
|
||||
self.objFunc.startup(m0)
|
||||
self.m = self.opt.minimize(self.objFunc.evalFunction, m0)
|
||||
self.finish()
|
||||
|
||||
return self.m
|
||||
|
||||
@Utils.callHooks('startup')
|
||||
def startup(self, m0):
|
||||
"""
|
||||
**startup** is called at the start of any new run call.
|
||||
|
||||
:param numpy.ndarray x0: initial x
|
||||
:rtype: None
|
||||
:return: None
|
||||
"""
|
||||
|
||||
if not hasattr(self.reg, '_mref'):
|
||||
print 'Regularization has not set mref. SimPEG will set it to m0.'
|
||||
self.reg.mref = m0
|
||||
|
||||
self.m = m0
|
||||
self._iter = 0
|
||||
self._beta = None
|
||||
self.phi_d_last = np.nan
|
||||
self.phi_m_last = np.nan
|
||||
|
||||
@Utils.callHooks('doStartIteration')
|
||||
def doStartIteration(self):
|
||||
"""
|
||||
**doStartIteration** is called at the end of each run iteration.
|
||||
|
||||
:rtype: None
|
||||
:return: None
|
||||
"""
|
||||
self._beta = self.getBeta()
|
||||
|
||||
|
||||
@Utils.callHooks('doEndIteration')
|
||||
def doEndIteration(self):
|
||||
"""
|
||||
**doEndIteration** is called at the end of each run iteration.
|
||||
|
||||
:rtype: None
|
||||
:return: None
|
||||
"""
|
||||
# store old values
|
||||
self.phi_d_last = self.phi_d
|
||||
self.phi_m_last = self.phi_m
|
||||
self._iter += 1
|
||||
|
||||
|
||||
def stoppingCriteria(self):
|
||||
if self.debug: print 'checking stoppingCriteria'
|
||||
return Utils.checkStoppers(self, self.stoppers)
|
||||
|
||||
|
||||
def printDone(self):
|
||||
"""
|
||||
**printDone** is called at the end of the inversion routine.
|
||||
|
||||
"""
|
||||
Utils.printStoppers(self, self.stoppers)
|
||||
|
||||
@Utils.callHooks('finish')
|
||||
def finish(self):
|
||||
"""finish()
|
||||
|
||||
+38
-16
@@ -1,16 +1,16 @@
|
||||
from SimPEG import Utils
|
||||
from SimPEG import Utils, np, sp
|
||||
|
||||
class BaseObjFunction(object):
|
||||
"""docstring for BaseObjFunction"""
|
||||
|
||||
__metaclass__ = Utils.Save.Savable
|
||||
|
||||
beta = None #: Regularization trade-off parameter
|
||||
beta = None #: Regularization trade-off parameter
|
||||
debug = False #: Print debugging information
|
||||
counter = None #: Set this to a SimPEG.Utils.Counter() if you want to count things
|
||||
|
||||
name = 'BaseObjFunction' #: Name of the objective function
|
||||
|
||||
counter = None #: Set this to a SimPEG.Utils.Counter() if you want to count things
|
||||
|
||||
u_current = None #: The most current evaluated field
|
||||
m_current = None #: The most current model
|
||||
|
||||
@@ -31,6 +31,24 @@ class BaseObjFunction(object):
|
||||
self.data = data
|
||||
self.reg = reg
|
||||
|
||||
|
||||
@Utils.callHooks('startup')
|
||||
def startup(self, m0):
|
||||
"""startup(m0)
|
||||
|
||||
Called when inversion is first starting.
|
||||
"""
|
||||
if self.debug: print 'Calling ObjFunction.startup'
|
||||
|
||||
if not hasattr(self.reg, '_mref'):
|
||||
print 'Regularization has not set mref. SimPEG will set it to m0.'
|
||||
self.reg.mref = m0
|
||||
|
||||
self.phi_d = np.nan
|
||||
self.phi_m = np.nan
|
||||
|
||||
self.m_current = m0
|
||||
|
||||
@Utils.timeIt
|
||||
def evalFunction(self, m, return_g=True, return_H=True):
|
||||
"""evalFunction(m, return_g=True, return_H=True)
|
||||
@@ -42,15 +60,15 @@ class BaseObjFunction(object):
|
||||
u = self.data.prob.field(m)
|
||||
self.u_current = u
|
||||
|
||||
if self._iter is 0 and self._beta is None:
|
||||
self._beta = self.beta0 = self.estimateBeta0(u=u,ratio=self.beta0_ratio)
|
||||
|
||||
phi_d = self.dataObj(m, u=u)
|
||||
phi_m = self.reg.modelObj(m)
|
||||
|
||||
self.dpred = self.data.dpred(m, u=u) # This is a cheap matrix vector calculation.
|
||||
self.phi_d = phi_d
|
||||
self.phi_m = phi_m
|
||||
|
||||
self.phi_d, self.phi_d_last = phi_d, self.phi_d
|
||||
self.phi_m, self.phi_m_last = phi_m, self.phi_m
|
||||
|
||||
self._beta = self.beta.get() #TODO: This needs to be fixed.
|
||||
|
||||
f = phi_d + self._beta * phi_m
|
||||
|
||||
@@ -195,25 +213,28 @@ class BetaSchedule(Utils.Parameter):
|
||||
|
||||
beta = None #: Beta parameter
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, *args, **kwargs):
|
||||
Utils.Parameter.__init__(self, *args, **kwargs)
|
||||
Utils.setKwargs(self, **kwargs)
|
||||
|
||||
def initialize(self):
|
||||
self.beta = self.beta0
|
||||
|
||||
@Utils.requires('parent')
|
||||
def get(self):
|
||||
def nextIter(self):
|
||||
if self.beta is 'guess':
|
||||
if self.debug: print 'BetaSchedule is estimating Beta0.'
|
||||
self.beta = self.estimateBeta0()
|
||||
|
||||
invesion = self.parent.parent
|
||||
if inversion._iter > 0 and inversion._iter % self.coolingRate == 0:
|
||||
opt = self.parent.parent.opt
|
||||
if opt._iter > 0 and opt._iter % self.coolingRate == 0:
|
||||
if self.debug: print 'BetaSchedule is cooling Beta. Iteration: %d' % opt._iter
|
||||
self.beta /= self.coolingFactor
|
||||
|
||||
return self.beta
|
||||
|
||||
@Utils.requires('parent')
|
||||
def estimateBeta0(self, u=None):
|
||||
def estimateBeta0(self):
|
||||
"""estimateBeta0(u=None)
|
||||
|
||||
The initial beta is calculated by comparing the estimated
|
||||
@@ -246,10 +267,11 @@ class BetaSchedule(Utils.Parameter):
|
||||
:rtype: float
|
||||
:return: beta0
|
||||
"""
|
||||
objFunc =
|
||||
objFunc = self.parent
|
||||
data = objFunc.data
|
||||
|
||||
m = invesion.m
|
||||
m = objFunc.m_current
|
||||
u = objFunc.u_current
|
||||
|
||||
if u is None:
|
||||
u = data.prob.field(m)
|
||||
|
||||
@@ -71,9 +71,9 @@ class IterationPrinters(object):
|
||||
bSet = {"title": "bSet", "value": lambda M: np.sum(M.bindingSet(M.xc)), "width": 8, "format": "%d"}
|
||||
comment = {"title": "Comment", "value": lambda M: M.comment, "width": 12, "format": "%s"}
|
||||
|
||||
beta = {"title": "beta", "value": lambda M: M.parent._beta, "width": 10, "format": "%1.2e"}
|
||||
phi_d = {"title": "phi_d", "value": lambda M: M.parent.phi_d, "width": 10, "format": "%1.2e"}
|
||||
phi_m = {"title": "phi_m", "value": lambda M: M.parent.phi_m, "width": 10, "format": "%1.2e"}
|
||||
beta = {"title": "beta", "value": lambda M: M.parent.objFunc.beta.get(), "width": 10, "format": "%1.2e"}
|
||||
phi_d = {"title": "phi_d", "value": lambda M: M.parent.objFunc.phi_d, "width": 10, "format": "%1.2e"}
|
||||
phi_m = {"title": "phi_m", "value": lambda M: M.parent.objFunc.phi_m, "width": 10, "format": "%1.2e"}
|
||||
|
||||
|
||||
class Minimize(object):
|
||||
|
||||
@@ -270,7 +270,10 @@ def timeIt(f):
|
||||
class Parameter(object):
|
||||
"""Parameter"""
|
||||
|
||||
debug = False #: Print debugging information
|
||||
debug = False #: Print debugging information
|
||||
|
||||
current = None #: This hold
|
||||
currentIter = 0
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
@@ -285,10 +288,21 @@ class Parameter(object):
|
||||
print 'Parameter has switched to a new parent!'
|
||||
self._parent = p
|
||||
|
||||
@property
|
||||
def opt(self):
|
||||
return self.parent.parent.opt
|
||||
|
||||
def initialize(self):
|
||||
pass
|
||||
|
||||
def get(self):
|
||||
if (self.current is None or
|
||||
not self.opt._iter == self.currentIter):
|
||||
self.current = self.nextIter()
|
||||
self.currentIter = self.opt._iter
|
||||
return self.current
|
||||
|
||||
def nextIter(self):
|
||||
raise NotImplementedError('Getting the Parameter is not yet implemented.')
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user