mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-30 22:20:58 +08:00
Change callHooks structure.
This commit is contained in:
+13
-29
@@ -85,25 +85,19 @@ class BaseInversion(object):
|
||||
if self.stoppingCriteria(): break
|
||||
|
||||
self.printDone()
|
||||
self.finish()
|
||||
|
||||
return self.m
|
||||
|
||||
@callHooks('startup')
|
||||
def startup(self, m0):
|
||||
"""
|
||||
**startup** is called at the start of any new run call.
|
||||
|
||||
If you have things that also need to run on startup, you can create a method::
|
||||
|
||||
def _startup*(self, x0):
|
||||
pass
|
||||
|
||||
Where the * can be any string. If present, _startup* will be called at the start of the default startup call.
|
||||
You may also completely overwrite this function.
|
||||
|
||||
:param numpy.ndarray x0: initial x
|
||||
:rtype: None
|
||||
:return: None
|
||||
"""
|
||||
callHooks(self,'startup',m0)
|
||||
|
||||
if not hasattr(self.reg, '_mref'):
|
||||
print 'Regularization has not set mref. SimPEG will set it to m0.'
|
||||
@@ -115,43 +109,25 @@ class BaseInversion(object):
|
||||
self.phi_d_last = np.nan
|
||||
self.phi_m_last = np.nan
|
||||
|
||||
@callHooks('doStartIteration')
|
||||
def doStartIteration(self):
|
||||
"""
|
||||
**doStartIteration** is called at the end of each run iteration.
|
||||
|
||||
If you have things that also need to run at the end of every iteration, you can create a method::
|
||||
|
||||
def _doStartIteration*(self):
|
||||
pass
|
||||
|
||||
Where the * can be any string. If present, _doStartIteration* will be called at the start of the default doStartIteration call.
|
||||
You may also completely overwrite this function.
|
||||
|
||||
:rtype: None
|
||||
:return: None
|
||||
"""
|
||||
callHooks(self,'doStartIteration')
|
||||
|
||||
self._beta = self.getBeta()
|
||||
|
||||
|
||||
@callHooks('doEndIteration')
|
||||
def doEndIteration(self):
|
||||
"""
|
||||
**doEndIteration** is called at the end of each run iteration.
|
||||
|
||||
If you have things that also need to run at the end of every iteration, you can create a method::
|
||||
|
||||
def _doEndIteration*(self):
|
||||
pass
|
||||
|
||||
Where the * can be any string. If present, _doEndIteration* will be called at the start of the default doEndIteration call.
|
||||
You may also completely overwrite this function.
|
||||
|
||||
:rtype: None
|
||||
:return: None
|
||||
"""
|
||||
callHooks(self,'doEndIteration')
|
||||
|
||||
# store old values
|
||||
self.phi_d_last = self.phi_d
|
||||
self.phi_m_last = self.phi_m
|
||||
@@ -213,6 +189,14 @@ class BaseInversion(object):
|
||||
"""
|
||||
printStoppers(self, self.stoppers)
|
||||
|
||||
@callHooks('finish')
|
||||
def finish(self):
|
||||
"""finish()
|
||||
|
||||
**finish** is called at the end of the optimization.
|
||||
"""
|
||||
pass
|
||||
|
||||
@timeIt
|
||||
def evalFunction(self, m, return_g=True, return_H=True):
|
||||
"""evalFunction(m, return_g=True, return_H=True)
|
||||
|
||||
+15
-25
@@ -155,6 +155,7 @@ class Minimize(object):
|
||||
doEndIteration(xt)
|
||||
|
||||
printDone()
|
||||
finish()
|
||||
return xc
|
||||
"""
|
||||
self.evalFunction = evalFunction
|
||||
@@ -175,6 +176,7 @@ class Minimize(object):
|
||||
self.doEndIteration(xt)
|
||||
|
||||
self.printDone()
|
||||
self.finish()
|
||||
|
||||
return self.xc
|
||||
|
||||
@@ -188,6 +190,7 @@ class Minimize(object):
|
||||
def parent(self, value):
|
||||
self._parent = value
|
||||
|
||||
@callHooks('startup')
|
||||
def startup(self, x0):
|
||||
"""
|
||||
**startup** is called at the start of any new minimize call.
|
||||
@@ -198,19 +201,10 @@ class Minimize(object):
|
||||
xc = x0
|
||||
_iter = _iterLS = 0
|
||||
|
||||
If you have things that also need to run on startup, you can create a method::
|
||||
|
||||
def _startup*(self, x0):
|
||||
pass
|
||||
|
||||
Where the * can be any string. If present, _startup* will be called at the start of the default startup call.
|
||||
You may also completely overwrite this function.
|
||||
|
||||
:param numpy.ndarray x0: initial x
|
||||
:rtype: None
|
||||
:return: None
|
||||
"""
|
||||
callHooks(self,'startup',x0)
|
||||
|
||||
self._iter = 0
|
||||
self._iterLS = 0
|
||||
@@ -222,6 +216,7 @@ class Minimize(object):
|
||||
self.x_last = x0
|
||||
|
||||
@count
|
||||
@callHooks('doStartIteration')
|
||||
def doStartIteration(self):
|
||||
"""doStartIteration()
|
||||
|
||||
@@ -230,7 +225,8 @@ class Minimize(object):
|
||||
:rtype: None
|
||||
:return: None
|
||||
"""
|
||||
callHooks(self,'doStartIteration')
|
||||
pass
|
||||
|
||||
|
||||
def printInit(self, inLS=False):
|
||||
"""
|
||||
@@ -244,6 +240,7 @@ class Minimize(object):
|
||||
name = self.name if not inLS else self.nameLS
|
||||
printTitles(self, self.printers if not inLS else self.printersLS, name, pad)
|
||||
|
||||
@callHooks('printIter')
|
||||
def printIter(self, inLS=False):
|
||||
"""
|
||||
**printIter** is called directly after function evaluations.
|
||||
@@ -252,8 +249,6 @@ class Minimize(object):
|
||||
parent.printIter function and call that.
|
||||
|
||||
"""
|
||||
callHooks(self,'printIter',inLS)
|
||||
|
||||
pad = ' '*10 if inLS else ''
|
||||
printLine(self, self.printers if not inLS else self.printersLS, pad=pad)
|
||||
|
||||
@@ -270,6 +265,11 @@ class Minimize(object):
|
||||
stoppers = self.stoppers if not inLS else self.stoppersLS
|
||||
printStoppers(self, stoppers, pad='', stop=stop, done=done)
|
||||
|
||||
|
||||
def finish(self):
|
||||
pass
|
||||
|
||||
|
||||
def stoppingCriteria(self, inLS=False):
|
||||
if self._iter == 0:
|
||||
self.f0 = self.f
|
||||
@@ -277,6 +277,7 @@ class Minimize(object):
|
||||
return checkStoppers(self, self.stoppers if not inLS else self.stoppersLS)
|
||||
|
||||
@timeIt
|
||||
@callHooks('projection')
|
||||
def projection(self, p):
|
||||
"""projection(p)
|
||||
|
||||
@@ -288,7 +289,6 @@ class Minimize(object):
|
||||
:rtype: numpy.ndarray
|
||||
:return: p, projected search direction
|
||||
"""
|
||||
callHooks(self,'projection',p)
|
||||
return p
|
||||
|
||||
@timeIt
|
||||
@@ -402,6 +402,7 @@ class Minimize(object):
|
||||
return p, False
|
||||
|
||||
@count
|
||||
@callHooks('doEndIteration')
|
||||
def doEndIteration(self, xt):
|
||||
"""doEndIteration(xt)
|
||||
|
||||
@@ -411,21 +412,10 @@ class Minimize(object):
|
||||
|
||||
self.xc must be updated in this code.
|
||||
|
||||
|
||||
If you have things that also need to run at the end of every iteration, you can create a method::
|
||||
|
||||
def _doEndIteration*(self, xt):
|
||||
pass
|
||||
|
||||
Where the * can be any string. If present, _doEndIteration* will be called at the start of the default doEndIteration call.
|
||||
You may also completely overwrite this function.
|
||||
|
||||
:param numpy.ndarray xt: tested new iterate that ensures a descent direction.
|
||||
:rtype: None
|
||||
:return: None
|
||||
"""
|
||||
callHooks(self,'doEndIteration',xt)
|
||||
|
||||
# store old values
|
||||
self.f_last = self.f
|
||||
self.x_last, self.xc = self.xc, xt
|
||||
@@ -630,7 +620,7 @@ class ProjectedGradient(Minimize, Remember):
|
||||
|
||||
if self.debug: print 'doEndIteration.ProjGrad, f_current_decrease: ', f_current_decrease
|
||||
if self.debug: print 'doEndIteration.ProjGrad, f_decrease_max: ', self.f_decrease_max
|
||||
if self.debug: print 'doEndIteration.ProjGrad, stopDoingSD: ', self.stopDoingSD
|
||||
if self.debug: print 'doEndIteration.ProjGrad, stopDoingSD: ', self.stopDoingPG
|
||||
|
||||
|
||||
class BFGS(Minimize, Remember):
|
||||
|
||||
@@ -32,7 +32,7 @@ class TestOptimizers(unittest.TestCase):
|
||||
self.assertTrue(np.linalg.norm(xopt-x_true,2) < TOL, True)
|
||||
|
||||
def test_ProjGradient_quadraticBounded(self):
|
||||
PG = inverse.ProjectedGradient()
|
||||
PG = inverse.ProjectedGradient(debug=True)
|
||||
PG.lower, PG.upper = -2, 2
|
||||
xopt = PG.minimize(getQuadratic(self.A,self.b),np.array([0,0]))
|
||||
x_true = np.array([2.,2.])
|
||||
|
||||
@@ -23,7 +23,7 @@ def hook(obj, method, name=None, overwrite=False, silent=False):
|
||||
|
||||
If name is None, the name of the method is used.
|
||||
"""
|
||||
if name is None:
|
||||
if name is None:
|
||||
name = method.__name__
|
||||
if name == '<lambda>':
|
||||
raise Exception('Must provide name to hook lambda functions.')
|
||||
@@ -87,11 +87,39 @@ def printStoppers(obj, stoppers, pad='', stop='STOP!', done='DONE!'):
|
||||
print pad + stopper['str'] % (l<=r,l,r)
|
||||
print pad + "%s%s%s" % ('-'*25,done,'-'*25)
|
||||
|
||||
def callHooks(obj, match, *args, **kwargs):
|
||||
for method in [posible for posible in dir(obj) if ('_'+match) in posible]:
|
||||
if getattr(obj,'debug',False): print (match+' is calling self.'+method)
|
||||
getattr(obj,method)(*args, **kwargs)
|
||||
def callHooks(match):
|
||||
"""
|
||||
Use this to wrap a funciton::
|
||||
|
||||
@callHooks('doEndIteration')
|
||||
def doEndIteration(self):
|
||||
pass
|
||||
|
||||
This will call everything named _doEndIteration* at the beginning of the function call.
|
||||
"""
|
||||
def callHooksWrap(f):
|
||||
@wraps(f)
|
||||
def wrapper(self,*args,**kwargs):
|
||||
|
||||
for method in [posible for posible in dir(self) if ('_'+match) in posible]:
|
||||
if getattr(self,'debug',False): print (match+' is calling self.'+method)
|
||||
getattr(self,method)(*args, **kwargs)
|
||||
|
||||
return f(self,*args,**kwargs)
|
||||
|
||||
extra = """
|
||||
If you have things that also need to run in the method %s, you can create a method::
|
||||
|
||||
def _%s*(self, ... ):
|
||||
pass
|
||||
|
||||
Where the * can be any string. If present, _%s* will be called at the start of the default %s call.
|
||||
You may also completely overwrite this function.
|
||||
""" % (match, match, match, match)
|
||||
|
||||
wrapper.__doc__ += extra
|
||||
return wrapper
|
||||
return callHooksWrap
|
||||
|
||||
|
||||
class Counter(object):
|
||||
|
||||
Reference in New Issue
Block a user