callHooks generalizes some of the hook calling code in Optimize and Inversion

This commit is contained in:
Rowan Cockett
2013-11-21 16:56:01 -08:00
parent 0bd971c22f
commit dbaea1fda9
5 changed files with 28 additions and 34 deletions
+6 -10
View File
@@ -1,7 +1,7 @@
import numpy as np
import scipy.sparse as sp
import SimPEG
from SimPEG.utils import sdiag, mkvc, setKwargs, checkStoppers, printStoppers, count, timeIt
from SimPEG.utils import sdiag, mkvc, setKwargs, checkStoppers, printStoppers, count, timeIt, callHooks
from Optimize import Remember
from BetaSchedule import Cooling
@@ -31,9 +31,9 @@ class BaseInversion(object):
self.opt.printers.insert(3,SimPEG.inverse.IterationPrinters.phi_m)
self.opt.stoppers.append(SimPEG.inverse.StoppingCriteria.phi_d_target_Minimize)
if not hasattr(opt, '_bfgsH0'): # Check if it has been set by the user and the default is not being used.
print 'Setting bfgsH0 to the inverse of the modelObj2Deriv.'
opt.bfgsH0 = SimPEG.Solver(reg.modelObj2Deriv(),doDirect=True,options={'factorize':True}) # False, options={'M':'GS','maxIter':15}
if not hasattr(opt, '_bfgsH0') and hasattr(opt, 'bfgsH0'): # Check if it has been set by the user and the default is not being used.
print 'Setting bfgsH0 to the inverse of the modelObj2Deriv. Done using direct methods.'
opt.bfgsH0 = SimPEG.Solver(reg.modelObj2Deriv())
@property
@@ -94,9 +94,7 @@ class BaseInversion(object):
:rtype: None
:return: None
"""
for method in [posible for posible in dir(self) if '_startup' in posible]:
if self.debug: print 'startup is calling self.'+method
getattr(self,method)(m0)
callHooks(self,'startup',m0)
if not hasattr(self.reg, '_mref'):
print 'Regularization has not set mref. SimPEG will set it to m0.'
@@ -124,9 +122,7 @@ class BaseInversion(object):
:rtype: None
:return: None
"""
for method in [posible for posible in dir(self) if '_doEndIteration' in posible]:
if self.debug: print 'doEndIteration is calling self.'+method
getattr(self,method)()
callHooks(self,'doEndIteration')
# store old values
self.phi_d_last = self.phi_d
+5 -15
View File
@@ -1,6 +1,6 @@
import numpy as np
import matplotlib.pyplot as plt
from SimPEG.utils import mkvc, sdiag, setKwargs, printTitles, printLine, printStoppers, checkStoppers, count, timeIt
from SimPEG.utils import mkvc, sdiag, setKwargs, printTitles, printLine, printStoppers, checkStoppers, count, timeIt, callHooks
norm = np.linalg.norm
import scipy.sparse as sp
from SimPEG import Solver
@@ -104,7 +104,6 @@ class Minimize(object):
counter = None
def __init__(self, **kwargs):
self._id = int(np.random.rand()*1e6) # create a unique identifier to this program to be used in pubsub
self.stoppers = [StoppingCriteria.tolerance_f, StoppingCriteria.moving_x, StoppingCriteria.tolerance_g, StoppingCriteria.norm_g, StoppingCriteria.iteration]
self.stoppersLS = [StoppingCriteria.armijoGoldstein, StoppingCriteria.iterationLS]
@@ -208,9 +207,7 @@ class Minimize(object):
:rtype: None
:return: None
"""
for method in [posible for posible in dir(self) if '_startup' in posible]:
if self.debug: print 'startup is calling self.'+method
getattr(self,method)(x0)
callHooks(self,'startup',x0)
self._iter = 0
self._iterLS = 0
@@ -230,7 +227,6 @@ class Minimize(object):
parent.printInit function and call that.
"""
if doPub and not inLS: pub.sendMessage('Minimize.printInit', minimize=self)
pad = ' '*10 if inLS else ''
name = self.name if not inLS else self.nameLS
printTitles(self, self.printers if not inLS else self.printersLS, name, pad)
@@ -244,12 +240,8 @@ class Minimize(object):
parent.printIter function and call that.
"""
callHooks(self,'printIter',inLS)
for method in [posible for posible in dir(self) if '_printIter' in posible]:
if self.debug: print 'printIter is calling self.'+method
getattr(self,method)(inLS)
if doPub and not inLS: pub.sendMessage('Minimize.printIter', minimize=self)
pad = ' '*10 if inLS else ''
printLine(self, self.printers if not inLS else self.printersLS, pad=pad)
@@ -261,7 +253,6 @@ class Minimize(object):
parent.printDone function and call that.
"""
if doPub and not inLS: pub.sendMessage('Minimize.printDone', minimize=self)
pad = ' '*10 if inLS else ''
stop, done = (' STOP! ', ' DONE! ') if not inLS else ('----------------', ' End Linesearch ')
stoppers = self.stoppers if not inLS else self.stoppersLS
@@ -285,6 +276,7 @@ class Minimize(object):
:rtype: numpy.ndarray
:return: p, projected search direction
"""
callHooks(self,'projection',p)
return p
@timeIt
@@ -415,9 +407,7 @@ class Minimize(object):
:rtype: None
:return: None
"""
for method in [posible for posible in dir(self) if '_doEndIteration' in posible]:
if self.debug: print 'doEndIteration is calling self.'+method
getattr(self,method)(xt)
callHooks(self,'doEndIteration',xt)
# store old values
self.f_last = self.f
+2 -4
View File
@@ -318,8 +318,7 @@ class DiffOperators(object):
def cellGrady():
doc = "Cell centered Gradient in the x dimension. Has neumann boundary conditions."
def fget(self):
if self.dim < 2:
return None
if self.dim < 2: return None
if getattr(self, '_cellGrady', None) is None:
BC = ['neumann', 'neumann']
n = self.n
@@ -338,8 +337,7 @@ class DiffOperators(object):
def cellGradz():
doc = "Cell centered Gradient in the x dimension. Has neumann boundary conditions."
def fget(self):
if self.dim < 3:
return None
if self.dim < 3: return None
if getattr(self, '_cellGradz', None) is None:
BC = ['neumann', 'neumann']
n = self.n
+7 -5
View File
@@ -1,4 +1,4 @@
from SimPEG.utils import sdiag, count, timeIt
from SimPEG.utils import sdiag, count, timeIt, setKwargs
import numpy as np
class Regularization(object):
@@ -41,15 +41,17 @@ class Regularization(object):
return self._Wz
alpha_s = 1e-6
alpha_x = 1
alpha_y = 1
alpha_z = 1
alpha_x = 1.0
alpha_y = 1.0
alpha_z = 1.0
counter = None
def __init__(self, mesh):
def __init__(self, mesh, **kwargs):
setKwargs(self, **kwargs)
self.mesh = mesh
def pnorm(self, r):
return 0.5*r.dot(r)
+8
View File
@@ -61,6 +61,12 @@ def printStoppers(obj, stoppers, pad='', stop='STOP!', done='DONE!'):
print pad + stopper['str'] % (l<=r,l,r)
print pad + "%s%s%s" % ('-'*25,done,'-'*25)
def callHooks(obj, match, *args, **kwargs):
for method in [posible for posible in dir(obj) if ('_'+match) in posible]:
if getattr(obj,'debug',False): print (match+' is calling self.'+method)
getattr(obj,method)(*args, **kwargs)
import time
import numpy as np
@@ -154,6 +160,8 @@ def timeIt(f):
if type(counter) is Counter: counter.countToc(self.__class__.__name__+'.'+f.__name__)
return out
return wrapper
if __name__ == '__main__':
class MyClass(object):
def __init__(self, url):