diff --git a/SimPEG/Inversion.py b/SimPEG/Inversion.py index 63a42bb3..65abb6eb 100644 --- a/SimPEG/Inversion.py +++ b/SimPEG/Inversion.py @@ -8,7 +8,6 @@ class BaseInversion(object): __metaclass__ = Utils.Save.Savable - maxIter = 1 #: Maximum number of iterations name = 'BaseInversion' debug = False #: Print debugging information @@ -16,11 +15,11 @@ class BaseInversion(object): comment = '' #: Used by some functions to indicate what is going on in the algorithm counter = None #: Set this to a SimPEG.Utils.Counter() if you want to count things - def __init__(self, dataObj, opt, **kwargs): + def __init__(self, objFunc, opt, **kwargs): Utils.setKwargs(self, **kwargs) - self.dataObj = dataObj - self.dataObj.parent = self + self.objFunc = objFunc + self.objFunc.parent = self self.opt = opt self.opt.parent = self @@ -36,7 +35,7 @@ class BaseInversion(object): if not hasattr(opt, '_bfgsH0') and hasattr(opt, 'bfgsH0'): # Check if it has been set by the user and the default is not being used. #TODO: I don't think that this if statement is working... print 'Setting bfgsH0 to the inverse of the modelObj2Deriv. Done using direct methods.' - opt.bfgsH0 = SimPEG.Solver(reg.modelObj2Deriv()) + opt.bfgsH0 = SimPEG.Solver(objFunc.reg.modelObj2Deriv()) @property @@ -63,75 +62,17 @@ class BaseInversion(object): Runs the inversion! """ - self.startup(m0) - while True: - self.doStartIteration() - self.m = self.opt.minimize(self.evalFunction, self.m) - self.doEndIteration() - if self.stoppingCriteria(): break - - self.printDone() + self.objFunc.startup(m0) + self.m = self.opt.minimize(self.objFunc.evalFunction, m0) self.finish() return self.m - @Utils.callHooks('startup') - def startup(self, m0): - """ - **startup** is called at the start of any new run call. - - :param numpy.ndarray x0: initial x - :rtype: None - :return: None - """ - - if not hasattr(self.reg, '_mref'): - print 'Regularization has not set mref. SimPEG will set it to m0.' - self.reg.mref = m0 - - self.m = m0 - self._iter = 0 - self._beta = None - self.phi_d_last = np.nan - self.phi_m_last = np.nan - - @Utils.callHooks('doStartIteration') - def doStartIteration(self): - """ - **doStartIteration** is called at the end of each run iteration. - - :rtype: None - :return: None - """ - self._beta = self.getBeta() - - - @Utils.callHooks('doEndIteration') - def doEndIteration(self): - """ - **doEndIteration** is called at the end of each run iteration. - - :rtype: None - :return: None - """ - # store old values - self.phi_d_last = self.phi_d - self.phi_m_last = self.phi_m - self._iter += 1 - - def stoppingCriteria(self): if self.debug: print 'checking stoppingCriteria' return Utils.checkStoppers(self, self.stoppers) - def printDone(self): - """ - **printDone** is called at the end of the inversion routine. - - """ - Utils.printStoppers(self, self.stoppers) - @Utils.callHooks('finish') def finish(self): """finish() diff --git a/SimPEG/ObjFunction.py b/SimPEG/ObjFunction.py index 38323cd3..c9578c54 100644 --- a/SimPEG/ObjFunction.py +++ b/SimPEG/ObjFunction.py @@ -1,16 +1,16 @@ -from SimPEG import Utils +from SimPEG import Utils, np, sp class BaseObjFunction(object): """docstring for BaseObjFunction""" __metaclass__ = Utils.Save.Savable - beta = None #: Regularization trade-off parameter + beta = None #: Regularization trade-off parameter + debug = False #: Print debugging information + counter = None #: Set this to a SimPEG.Utils.Counter() if you want to count things name = 'BaseObjFunction' #: Name of the objective function - counter = None #: Set this to a SimPEG.Utils.Counter() if you want to count things - u_current = None #: The most current evaluated field m_current = None #: The most current model @@ -31,6 +31,24 @@ class BaseObjFunction(object): self.data = data self.reg = reg + + @Utils.callHooks('startup') + def startup(self, m0): + """startup(m0) + + Called when inversion is first starting. + """ + if self.debug: print 'Calling ObjFunction.startup' + + if not hasattr(self.reg, '_mref'): + print 'Regularization has not set mref. SimPEG will set it to m0.' + self.reg.mref = m0 + + self.phi_d = np.nan + self.phi_m = np.nan + + self.m_current = m0 + @Utils.timeIt def evalFunction(self, m, return_g=True, return_H=True): """evalFunction(m, return_g=True, return_H=True) @@ -42,15 +60,15 @@ class BaseObjFunction(object): u = self.data.prob.field(m) self.u_current = u - if self._iter is 0 and self._beta is None: - self._beta = self.beta0 = self.estimateBeta0(u=u,ratio=self.beta0_ratio) - phi_d = self.dataObj(m, u=u) phi_m = self.reg.modelObj(m) self.dpred = self.data.dpred(m, u=u) # This is a cheap matrix vector calculation. - self.phi_d = phi_d - self.phi_m = phi_m + + self.phi_d, self.phi_d_last = phi_d, self.phi_d + self.phi_m, self.phi_m_last = phi_m, self.phi_m + + self._beta = self.beta.get() #TODO: This needs to be fixed. f = phi_d + self._beta * phi_m @@ -195,25 +213,28 @@ class BetaSchedule(Utils.Parameter): beta = None #: Beta parameter - def __init__(self, **kwargs): + def __init__(self, *args, **kwargs): + Utils.Parameter.__init__(self, *args, **kwargs) Utils.setKwargs(self, **kwargs) def initialize(self): self.beta = self.beta0 @Utils.requires('parent') - def get(self): + def nextIter(self): if self.beta is 'guess': + if self.debug: print 'BetaSchedule is estimating Beta0.' self.beta = self.estimateBeta0() - invesion = self.parent.parent - if inversion._iter > 0 and inversion._iter % self.coolingRate == 0: + opt = self.parent.parent.opt + if opt._iter > 0 and opt._iter % self.coolingRate == 0: + if self.debug: print 'BetaSchedule is cooling Beta. Iteration: %d' % opt._iter self.beta /= self.coolingFactor return self.beta @Utils.requires('parent') - def estimateBeta0(self, u=None): + def estimateBeta0(self): """estimateBeta0(u=None) The initial beta is calculated by comparing the estimated @@ -246,10 +267,11 @@ class BetaSchedule(Utils.Parameter): :rtype: float :return: beta0 """ - objFunc = + objFunc = self.parent data = objFunc.data - m = invesion.m + m = objFunc.m_current + u = objFunc.u_current if u is None: u = data.prob.field(m) diff --git a/SimPEG/Optimization.py b/SimPEG/Optimization.py index 3157b5c1..afb2b2a2 100644 --- a/SimPEG/Optimization.py +++ b/SimPEG/Optimization.py @@ -71,9 +71,9 @@ class IterationPrinters(object): bSet = {"title": "bSet", "value": lambda M: np.sum(M.bindingSet(M.xc)), "width": 8, "format": "%d"} comment = {"title": "Comment", "value": lambda M: M.comment, "width": 12, "format": "%s"} - beta = {"title": "beta", "value": lambda M: M.parent._beta, "width": 10, "format": "%1.2e"} - phi_d = {"title": "phi_d", "value": lambda M: M.parent.phi_d, "width": 10, "format": "%1.2e"} - phi_m = {"title": "phi_m", "value": lambda M: M.parent.phi_m, "width": 10, "format": "%1.2e"} + beta = {"title": "beta", "value": lambda M: M.parent.objFunc.beta.get(), "width": 10, "format": "%1.2e"} + phi_d = {"title": "phi_d", "value": lambda M: M.parent.objFunc.phi_d, "width": 10, "format": "%1.2e"} + phi_m = {"title": "phi_m", "value": lambda M: M.parent.objFunc.phi_m, "width": 10, "format": "%1.2e"} class Minimize(object): diff --git a/SimPEG/Utils/__init__.py b/SimPEG/Utils/__init__.py index ab520ab4..f35ee5a0 100644 --- a/SimPEG/Utils/__init__.py +++ b/SimPEG/Utils/__init__.py @@ -270,7 +270,10 @@ def timeIt(f): class Parameter(object): """Parameter""" - debug = False #: Print debugging information + debug = False #: Print debugging information + + current = None #: This hold + currentIter = 0 def __init__(self, *args, **kwargs): pass @@ -285,10 +288,21 @@ class Parameter(object): print 'Parameter has switched to a new parent!' self._parent = p + @property + def opt(self): + return self.parent.parent.opt + def initialize(self): pass def get(self): + if (self.current is None or + not self.opt._iter == self.currentIter): + self.current = self.nextIter() + self.currentIter = self.opt._iter + return self.current + + def nextIter(self): raise NotImplementedError('Getting the Parameter is not yet implemented.')