mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-01 21:15:18 +08:00
callHooks generalizes some of the hook calling code in Optimize and Inversion
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
import SimPEG
|
||||
from SimPEG.utils import sdiag, mkvc, setKwargs, checkStoppers, printStoppers, count, timeIt
|
||||
from SimPEG.utils import sdiag, mkvc, setKwargs, checkStoppers, printStoppers, count, timeIt, callHooks
|
||||
from Optimize import Remember
|
||||
from BetaSchedule import Cooling
|
||||
|
||||
@@ -31,9 +31,9 @@ class BaseInversion(object):
|
||||
self.opt.printers.insert(3,SimPEG.inverse.IterationPrinters.phi_m)
|
||||
self.opt.stoppers.append(SimPEG.inverse.StoppingCriteria.phi_d_target_Minimize)
|
||||
|
||||
if not hasattr(opt, '_bfgsH0'): # Check if it has been set by the user and the default is not being used.
|
||||
print 'Setting bfgsH0 to the inverse of the modelObj2Deriv.'
|
||||
opt.bfgsH0 = SimPEG.Solver(reg.modelObj2Deriv(),doDirect=True,options={'factorize':True}) # False, options={'M':'GS','maxIter':15}
|
||||
if not hasattr(opt, '_bfgsH0') and hasattr(opt, 'bfgsH0'): # Check if it has been set by the user and the default is not being used.
|
||||
print 'Setting bfgsH0 to the inverse of the modelObj2Deriv. Done using direct methods.'
|
||||
opt.bfgsH0 = SimPEG.Solver(reg.modelObj2Deriv())
|
||||
|
||||
|
||||
@property
|
||||
@@ -94,9 +94,7 @@ class BaseInversion(object):
|
||||
:rtype: None
|
||||
:return: None
|
||||
"""
|
||||
for method in [posible for posible in dir(self) if '_startup' in posible]:
|
||||
if self.debug: print 'startup is calling self.'+method
|
||||
getattr(self,method)(m0)
|
||||
callHooks(self,'startup',m0)
|
||||
|
||||
if not hasattr(self.reg, '_mref'):
|
||||
print 'Regularization has not set mref. SimPEG will set it to m0.'
|
||||
@@ -124,9 +122,7 @@ class BaseInversion(object):
|
||||
:rtype: None
|
||||
:return: None
|
||||
"""
|
||||
for method in [posible for posible in dir(self) if '_doEndIteration' in posible]:
|
||||
if self.debug: print 'doEndIteration is calling self.'+method
|
||||
getattr(self,method)()
|
||||
callHooks(self,'doEndIteration')
|
||||
|
||||
# store old values
|
||||
self.phi_d_last = self.phi_d
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from SimPEG.utils import mkvc, sdiag, setKwargs, printTitles, printLine, printStoppers, checkStoppers, count, timeIt
|
||||
from SimPEG.utils import mkvc, sdiag, setKwargs, printTitles, printLine, printStoppers, checkStoppers, count, timeIt, callHooks
|
||||
norm = np.linalg.norm
|
||||
import scipy.sparse as sp
|
||||
from SimPEG import Solver
|
||||
@@ -104,7 +104,6 @@ class Minimize(object):
|
||||
counter = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._id = int(np.random.rand()*1e6) # create a unique identifier to this program to be used in pubsub
|
||||
self.stoppers = [StoppingCriteria.tolerance_f, StoppingCriteria.moving_x, StoppingCriteria.tolerance_g, StoppingCriteria.norm_g, StoppingCriteria.iteration]
|
||||
self.stoppersLS = [StoppingCriteria.armijoGoldstein, StoppingCriteria.iterationLS]
|
||||
|
||||
@@ -208,9 +207,7 @@ class Minimize(object):
|
||||
:rtype: None
|
||||
:return: None
|
||||
"""
|
||||
for method in [posible for posible in dir(self) if '_startup' in posible]:
|
||||
if self.debug: print 'startup is calling self.'+method
|
||||
getattr(self,method)(x0)
|
||||
callHooks(self,'startup',x0)
|
||||
|
||||
self._iter = 0
|
||||
self._iterLS = 0
|
||||
@@ -230,7 +227,6 @@ class Minimize(object):
|
||||
parent.printInit function and call that.
|
||||
|
||||
"""
|
||||
if doPub and not inLS: pub.sendMessage('Minimize.printInit', minimize=self)
|
||||
pad = ' '*10 if inLS else ''
|
||||
name = self.name if not inLS else self.nameLS
|
||||
printTitles(self, self.printers if not inLS else self.printersLS, name, pad)
|
||||
@@ -244,12 +240,8 @@ class Minimize(object):
|
||||
parent.printIter function and call that.
|
||||
|
||||
"""
|
||||
callHooks(self,'printIter',inLS)
|
||||
|
||||
for method in [posible for posible in dir(self) if '_printIter' in posible]:
|
||||
if self.debug: print 'printIter is calling self.'+method
|
||||
getattr(self,method)(inLS)
|
||||
|
||||
if doPub and not inLS: pub.sendMessage('Minimize.printIter', minimize=self)
|
||||
pad = ' '*10 if inLS else ''
|
||||
printLine(self, self.printers if not inLS else self.printersLS, pad=pad)
|
||||
|
||||
@@ -261,7 +253,6 @@ class Minimize(object):
|
||||
parent.printDone function and call that.
|
||||
|
||||
"""
|
||||
if doPub and not inLS: pub.sendMessage('Minimize.printDone', minimize=self)
|
||||
pad = ' '*10 if inLS else ''
|
||||
stop, done = (' STOP! ', ' DONE! ') if not inLS else ('----------------', ' End Linesearch ')
|
||||
stoppers = self.stoppers if not inLS else self.stoppersLS
|
||||
@@ -285,6 +276,7 @@ class Minimize(object):
|
||||
:rtype: numpy.ndarray
|
||||
:return: p, projected search direction
|
||||
"""
|
||||
callHooks(self,'projection',p)
|
||||
return p
|
||||
|
||||
@timeIt
|
||||
@@ -415,9 +407,7 @@ class Minimize(object):
|
||||
:rtype: None
|
||||
:return: None
|
||||
"""
|
||||
for method in [posible for posible in dir(self) if '_doEndIteration' in posible]:
|
||||
if self.debug: print 'doEndIteration is calling self.'+method
|
||||
getattr(self,method)(xt)
|
||||
callHooks(self,'doEndIteration',xt)
|
||||
|
||||
# store old values
|
||||
self.f_last = self.f
|
||||
|
||||
@@ -318,8 +318,7 @@ class DiffOperators(object):
|
||||
def cellGrady():
|
||||
doc = "Cell centered Gradient in the x dimension. Has neumann boundary conditions."
|
||||
def fget(self):
|
||||
if self.dim < 2:
|
||||
return None
|
||||
if self.dim < 2: return None
|
||||
if getattr(self, '_cellGrady', None) is None:
|
||||
BC = ['neumann', 'neumann']
|
||||
n = self.n
|
||||
@@ -338,8 +337,7 @@ class DiffOperators(object):
|
||||
def cellGradz():
|
||||
doc = "Cell centered Gradient in the x dimension. Has neumann boundary conditions."
|
||||
def fget(self):
|
||||
if self.dim < 3:
|
||||
return None
|
||||
if self.dim < 3: return None
|
||||
if getattr(self, '_cellGradz', None) is None:
|
||||
BC = ['neumann', 'neumann']
|
||||
n = self.n
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from SimPEG.utils import sdiag, count, timeIt
|
||||
from SimPEG.utils import sdiag, count, timeIt, setKwargs
|
||||
import numpy as np
|
||||
|
||||
class Regularization(object):
|
||||
@@ -41,15 +41,17 @@ class Regularization(object):
|
||||
return self._Wz
|
||||
|
||||
alpha_s = 1e-6
|
||||
alpha_x = 1
|
||||
alpha_y = 1
|
||||
alpha_z = 1
|
||||
alpha_x = 1.0
|
||||
alpha_y = 1.0
|
||||
alpha_z = 1.0
|
||||
|
||||
counter = None
|
||||
|
||||
def __init__(self, mesh):
|
||||
def __init__(self, mesh, **kwargs):
|
||||
setKwargs(self, **kwargs)
|
||||
self.mesh = mesh
|
||||
|
||||
|
||||
def pnorm(self, r):
|
||||
return 0.5*r.dot(r)
|
||||
|
||||
|
||||
@@ -61,6 +61,12 @@ def printStoppers(obj, stoppers, pad='', stop='STOP!', done='DONE!'):
|
||||
print pad + stopper['str'] % (l<=r,l,r)
|
||||
print pad + "%s%s%s" % ('-'*25,done,'-'*25)
|
||||
|
||||
def callHooks(obj, match, *args, **kwargs):
|
||||
for method in [posible for posible in dir(obj) if ('_'+match) in posible]:
|
||||
if getattr(obj,'debug',False): print (match+' is calling self.'+method)
|
||||
getattr(obj,method)(*args, **kwargs)
|
||||
|
||||
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
@@ -154,6 +160,8 @@ def timeIt(f):
|
||||
if type(counter) is Counter: counter.countToc(self.__class__.__name__+'.'+f.__name__)
|
||||
return out
|
||||
return wrapper
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
class MyClass(object):
|
||||
def __init__(self, url):
|
||||
|
||||
Reference in New Issue
Block a user