diff --git a/SimPEG/inverse/Inversion.py b/SimPEG/inverse/Inversion.py index 160678aa..a2ef2513 100644 --- a/SimPEG/inverse/Inversion.py +++ b/SimPEG/inverse/Inversion.py @@ -85,25 +85,19 @@ class BaseInversion(object): if self.stoppingCriteria(): break self.printDone() + self.finish() + return self.m + @callHooks('startup') def startup(self, m0): """ **startup** is called at the start of any new run call. - If you have things that also need to run on startup, you can create a method:: - - def _startup*(self, x0): - pass - - Where the * can be any string. If present, _startup* will be called at the start of the default startup call. - You may also completely overwrite this function. - :param numpy.ndarray x0: initial x :rtype: None :return: None """ - callHooks(self,'startup',m0) if not hasattr(self.reg, '_mref'): print 'Regularization has not set mref. SimPEG will set it to m0.' @@ -115,43 +109,25 @@ class BaseInversion(object): self.phi_d_last = np.nan self.phi_m_last = np.nan + @callHooks('doStartIteration') def doStartIteration(self): """ **doStartIteration** is called at the end of each run iteration. - If you have things that also need to run at the end of every iteration, you can create a method:: - - def _doStartIteration*(self): - pass - - Where the * can be any string. If present, _doStartIteration* will be called at the start of the default doStartIteration call. - You may also completely overwrite this function. - :rtype: None :return: None """ - callHooks(self,'doStartIteration') - self._beta = self.getBeta() + @callHooks('doEndIteration') def doEndIteration(self): """ **doEndIteration** is called at the end of each run iteration. - If you have things that also need to run at the end of every iteration, you can create a method:: - - def _doEndIteration*(self): - pass - - Where the * can be any string. If present, _doEndIteration* will be called at the start of the default doEndIteration call. - You may also completely overwrite this function. - :rtype: None :return: None """ - callHooks(self,'doEndIteration') - # store old values self.phi_d_last = self.phi_d self.phi_m_last = self.phi_m @@ -213,6 +189,14 @@ class BaseInversion(object): """ printStoppers(self, self.stoppers) + @callHooks('finish') + def finish(self): + """finish() + + **finish** is called at the end of the optimization. + """ + pass + @timeIt def evalFunction(self, m, return_g=True, return_H=True): """evalFunction(m, return_g=True, return_H=True) diff --git a/SimPEG/inverse/Optimize.py b/SimPEG/inverse/Optimize.py index 6edacabd..80a03632 100644 --- a/SimPEG/inverse/Optimize.py +++ b/SimPEG/inverse/Optimize.py @@ -155,6 +155,7 @@ class Minimize(object): doEndIteration(xt) printDone() + finish() return xc """ self.evalFunction = evalFunction @@ -175,6 +176,7 @@ class Minimize(object): self.doEndIteration(xt) self.printDone() + self.finish() return self.xc @@ -188,6 +190,7 @@ class Minimize(object): def parent(self, value): self._parent = value + @callHooks('startup') def startup(self, x0): """ **startup** is called at the start of any new minimize call. @@ -198,19 +201,10 @@ class Minimize(object): xc = x0 _iter = _iterLS = 0 - If you have things that also need to run on startup, you can create a method:: - - def _startup*(self, x0): - pass - - Where the * can be any string. If present, _startup* will be called at the start of the default startup call. - You may also completely overwrite this function. - :param numpy.ndarray x0: initial x :rtype: None :return: None """ - callHooks(self,'startup',x0) self._iter = 0 self._iterLS = 0 @@ -222,6 +216,7 @@ class Minimize(object): self.x_last = x0 @count + @callHooks('doStartIteration') def doStartIteration(self): """doStartIteration() @@ -230,7 +225,8 @@ class Minimize(object): :rtype: None :return: None """ - callHooks(self,'doStartIteration') + pass + def printInit(self, inLS=False): """ @@ -244,6 +240,7 @@ class Minimize(object): name = self.name if not inLS else self.nameLS printTitles(self, self.printers if not inLS else self.printersLS, name, pad) + @callHooks('printIter') def printIter(self, inLS=False): """ **printIter** is called directly after function evaluations. @@ -252,8 +249,6 @@ class Minimize(object): parent.printIter function and call that. """ - callHooks(self,'printIter',inLS) - pad = ' '*10 if inLS else '' printLine(self, self.printers if not inLS else self.printersLS, pad=pad) @@ -270,6 +265,11 @@ class Minimize(object): stoppers = self.stoppers if not inLS else self.stoppersLS printStoppers(self, stoppers, pad='', stop=stop, done=done) + + def finish(self): + pass + + def stoppingCriteria(self, inLS=False): if self._iter == 0: self.f0 = self.f @@ -277,6 +277,7 @@ class Minimize(object): return checkStoppers(self, self.stoppers if not inLS else self.stoppersLS) @timeIt + @callHooks('projection') def projection(self, p): """projection(p) @@ -288,7 +289,6 @@ class Minimize(object): :rtype: numpy.ndarray :return: p, projected search direction """ - callHooks(self,'projection',p) return p @timeIt @@ -402,6 +402,7 @@ class Minimize(object): return p, False @count + @callHooks('doEndIteration') def doEndIteration(self, xt): """doEndIteration(xt) @@ -411,21 +412,10 @@ class Minimize(object): self.xc must be updated in this code. - - If you have things that also need to run at the end of every iteration, you can create a method:: - - def _doEndIteration*(self, xt): - pass - - Where the * can be any string. If present, _doEndIteration* will be called at the start of the default doEndIteration call. - You may also completely overwrite this function. - :param numpy.ndarray xt: tested new iterate that ensures a descent direction. :rtype: None :return: None """ - callHooks(self,'doEndIteration',xt) - # store old values self.f_last = self.f self.x_last, self.xc = self.xc, xt @@ -630,7 +620,7 @@ class ProjectedGradient(Minimize, Remember): if self.debug: print 'doEndIteration.ProjGrad, f_current_decrease: ', f_current_decrease if self.debug: print 'doEndIteration.ProjGrad, f_decrease_max: ', self.f_decrease_max - if self.debug: print 'doEndIteration.ProjGrad, stopDoingSD: ', self.stopDoingSD + if self.debug: print 'doEndIteration.ProjGrad, stopDoingSD: ', self.stopDoingPG class BFGS(Minimize, Remember): diff --git a/SimPEG/tests/test_optimizers.py b/SimPEG/tests/test_optimizers.py index 654d1804..0253852c 100644 --- a/SimPEG/tests/test_optimizers.py +++ b/SimPEG/tests/test_optimizers.py @@ -32,7 +32,7 @@ class TestOptimizers(unittest.TestCase): self.assertTrue(np.linalg.norm(xopt-x_true,2) < TOL, True) def test_ProjGradient_quadraticBounded(self): - PG = inverse.ProjectedGradient() + PG = inverse.ProjectedGradient(debug=True) PG.lower, PG.upper = -2, 2 xopt = PG.minimize(getQuadratic(self.A,self.b),np.array([0,0])) x_true = np.array([2.,2.]) diff --git a/SimPEG/utils/__init__.py b/SimPEG/utils/__init__.py index 622557d6..19aecec6 100644 --- a/SimPEG/utils/__init__.py +++ b/SimPEG/utils/__init__.py @@ -23,7 +23,7 @@ def hook(obj, method, name=None, overwrite=False, silent=False): If name is None, the name of the method is used. """ - if name is None: + if name is None: name = method.__name__ if name == '': raise Exception('Must provide name to hook lambda functions.') @@ -87,11 +87,39 @@ def printStoppers(obj, stoppers, pad='', stop='STOP!', done='DONE!'): print pad + stopper['str'] % (l<=r,l,r) print pad + "%s%s%s" % ('-'*25,done,'-'*25) -def callHooks(obj, match, *args, **kwargs): - for method in [posible for posible in dir(obj) if ('_'+match) in posible]: - if getattr(obj,'debug',False): print (match+' is calling self.'+method) - getattr(obj,method)(*args, **kwargs) +def callHooks(match): + """ + Use this to wrap a funciton:: + @callHooks('doEndIteration') + def doEndIteration(self): + pass + + This will call everything named _doEndIteration* at the beginning of the function call. + """ + def callHooksWrap(f): + @wraps(f) + def wrapper(self,*args,**kwargs): + + for method in [posible for posible in dir(self) if ('_'+match) in posible]: + if getattr(self,'debug',False): print (match+' is calling self.'+method) + getattr(self,method)(*args, **kwargs) + + return f(self,*args,**kwargs) + + extra = """ + If you have things that also need to run in the method %s, you can create a method:: + + def _%s*(self, ... ): + pass + + Where the * can be any string. If present, _%s* will be called at the start of the default %s call. + You may also completely overwrite this function. + """ % (match, match, match, match) + + wrapper.__doc__ += extra + return wrapper + return callHooksWrap class Counter(object):