mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-29 00:23:53 +08:00
Stopping Criteria and Printers generalized in Optimize.
This commit is contained in:
@@ -9,11 +9,27 @@ class Inversion(object):
|
||||
name = 'SimPEG Inversion'
|
||||
|
||||
def __init__(self, prob, reg, opt, **kwargs):
|
||||
self.setKwargs(**kwargs)
|
||||
self.prob = prob
|
||||
self.reg = reg
|
||||
self.opt = opt
|
||||
self.opt.parent = self
|
||||
self.setKwargs(**kwargs)
|
||||
|
||||
# Check if we have inserted printers into the optimization
|
||||
haveInserted = False
|
||||
for printer in self.opt.printers:
|
||||
haveInserted = haveInserted or printer["title"] == 'phi_d'
|
||||
|
||||
if not haveInserted:
|
||||
self.opt.printers.insert(1,{"title": "beta",
|
||||
"value": lambda M: M.parent._beta,
|
||||
"width": 13, "format": "%1.2e"})
|
||||
self.opt.printers.insert(2,{"title": "phi_d",
|
||||
"value": lambda M: M.parent._phi_d_last,
|
||||
"width": 13, "format": "%1.2e"})
|
||||
self.opt.printers.insert(3,{"title": "phi_m",
|
||||
"value": lambda M: M.parent._phi_m_last,
|
||||
"width": 13, "format": "%1.2e"})
|
||||
|
||||
def setKwargs(self, **kwargs):
|
||||
"""Sets key word arguments (kwargs) that are present in the object, throw an error if they don't exist."""
|
||||
@@ -23,13 +39,13 @@ class Inversion(object):
|
||||
else:
|
||||
raise Exception('%s attr is not recognized' % attr)
|
||||
|
||||
def printInit(self):
|
||||
print "%s %s %s" % ('='*22, self.name, '='*22)
|
||||
print " # beta phi_d phi_m f norm(dJ) #LS"
|
||||
print "%s" % '-'*62
|
||||
# def printInit(self):
|
||||
# print "%s %s %s" % ('='*22, self.name, '='*22)
|
||||
# print " # beta phi_d phi_m f norm(dJ) #LS"
|
||||
# print "%s" % '-'*62
|
||||
|
||||
def printIter(self):
|
||||
print "%3d %1.2e %1.2e %1.2e %1.2e %1.2e %3d" % (self.opt._iter, self._beta, self._phi_d_last, self._phi_m_last, self.opt.f, np.linalg.norm(self.opt.g), self.opt._iterLS)
|
||||
# def printIter(self):
|
||||
# print "%3d %1.2e %1.2e %1.2e %1.2e %1.2e %3d" % (self.opt._iter, self._beta, self._phi_d_last, self._phi_m_last, self.opt.f, np.linalg.norm(self.opt.g), self.opt._iterLS)
|
||||
|
||||
@property
|
||||
def Wd(self):
|
||||
|
||||
+94
-50
@@ -33,17 +33,18 @@ class Minimize(object):
|
||||
tolX = 1e-1
|
||||
tolG = 1e-1
|
||||
eps = 1e-5
|
||||
debug = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._id = int(np.random.rand()*1e6) # create a unique identifier to this program to be used in pubsub
|
||||
self.stoppers = [{
|
||||
"str": "%d : |fc-fOld| = %1.4e <= tolF*(1+|f0|) = %1.4e",
|
||||
"left": lambda M: 1 if M._iter==0 else abs(M.f-M.fOld),
|
||||
"left": lambda M: 1 if M._iter==0 else abs(M.f-M.f_last),
|
||||
"right": lambda M: 0 if M._iter==0 else M.tolF*(1+abs(M.f0)),
|
||||
"stopType": "optimal"
|
||||
},{
|
||||
"str": "%d : |xc-xOld| = %1.4e <= tolX*(1+|x0|) = %1.4e",
|
||||
"left": lambda M: 1 if M._iter==0 else norm(M.xc-M.xOld),
|
||||
"str": "%d : |xc-x_last| = %1.4e <= tolX*(1+|x0|) = %1.4e",
|
||||
"left": lambda M: 1 if M._iter==0 else norm(M.xc-M.x_last),
|
||||
"right": lambda M: 0 if M._iter==0 else M.tolX*(1+norm(M.x0)),
|
||||
"stopType": "optimal"
|
||||
},{
|
||||
@@ -57,17 +58,28 @@ class Minimize(object):
|
||||
"right": lambda M: 1e3*M.eps,
|
||||
"stopType": "critical"
|
||||
},{
|
||||
"str": "%d : maxIter = %3d\t <= iter\t = %3d",
|
||||
"str": "%d : maxIter = %3d <= iter = %3d",
|
||||
"left": lambda M: M.maxIter,
|
||||
"right": lambda M: M._iter,
|
||||
"stopType": "critical"
|
||||
}]
|
||||
# print "%3d\t%1.2e\t%1.2e\t%d" % (self._iter, self.f, norm(self.g), self._iterLS)
|
||||
|
||||
self.stoppersLS = [{
|
||||
"str": "%d : ft = %1.4e <= alp*descent = %1.4e",
|
||||
"left": lambda M: M._LS_ft,
|
||||
"right": lambda M: M.f + self.LSreduction * M._LS_descent,
|
||||
"stopType": "optimal"
|
||||
},{
|
||||
"str": "%d : maxIterLS = %3d <= iterLS = %3d",
|
||||
"left": lambda M: M.maxIterLS,
|
||||
"right": lambda M: M._iterLS,
|
||||
"stopType": "critical"
|
||||
}]
|
||||
|
||||
self.printers = [{
|
||||
"title": "#",
|
||||
"value": lambda M: M._iter,
|
||||
"width": 5,
|
||||
"width": 10,
|
||||
"format": "%3d"
|
||||
},{
|
||||
"title": "f",
|
||||
@@ -85,6 +97,29 @@ class Minimize(object):
|
||||
"width": 5,
|
||||
"format": "%d"
|
||||
}]
|
||||
|
||||
self.printersLS = [{
|
||||
"title": "#",
|
||||
"value": lambda M: (M._iter, M._iterLS),
|
||||
"width": 10,
|
||||
"format": "%3d.%d"
|
||||
},{
|
||||
"title": "t",
|
||||
"value": lambda M: M._LS_t,
|
||||
"width": 14,
|
||||
"format": "%0.5f"
|
||||
},{
|
||||
"title": "ft",
|
||||
"value": lambda M: M._LS_ft,
|
||||
"width": 14,
|
||||
"format": "%1.2e"
|
||||
},{
|
||||
"title": "f + alp*g.T*p",
|
||||
"value": lambda M: M.f + M.LSreduction*M._LS_descent,
|
||||
"width": 16,
|
||||
"format": "%1.2e"
|
||||
}]
|
||||
|
||||
self.setKwargs(**kwargs)
|
||||
|
||||
def setKwargs(self, **kwargs):
|
||||
@@ -157,7 +192,6 @@ class Minimize(object):
|
||||
while True:
|
||||
self.f, self.g, self.H = evalFunction(self.xc, return_g=True, return_H=True)
|
||||
if doPub: pub.sendMessage('Minimize.evalFunction', minimize=self, f=self.f, g=self.g, H=self.H)
|
||||
self.printIter()
|
||||
if self.stoppingCriteria(): break
|
||||
p = self.findSearchDirection()
|
||||
if doPub: pub.sendMessage('Minimize.searchDirection', minimize=self, p=p)
|
||||
@@ -170,6 +204,7 @@ class Minimize(object):
|
||||
if not caught: return self.xc
|
||||
self.doEndIteration(xt)
|
||||
if doPub: pub.sendMessage('Minimize.endIteration', minimize=self, xt=xt)
|
||||
self.printIter()
|
||||
|
||||
self.printDone()
|
||||
|
||||
@@ -215,10 +250,11 @@ class Minimize(object):
|
||||
x0 = self.projection(x0) # ensure that we start of feasible.
|
||||
self.x0 = x0
|
||||
self.xc = x0
|
||||
self.xOld = x0
|
||||
self.f_last = np.nan
|
||||
self.x_last = x0
|
||||
|
||||
|
||||
def printInit(self):
|
||||
def printInit(self, inLS=False):
|
||||
"""
|
||||
**printInit** is called at the beginning of the optimization routine.
|
||||
|
||||
@@ -226,20 +262,21 @@ class Minimize(object):
|
||||
parent.printInit function and call that.
|
||||
|
||||
"""
|
||||
if doPub: pub.sendMessage('Minimize.printInit', minimize=self)
|
||||
if self.parent is not None and hasattr(self.parent, 'printInit'):
|
||||
self.parent.printInit()
|
||||
return
|
||||
if doPub and not inLS: pub.sendMessage('Minimize.printInit', minimize=self)
|
||||
pad = ' '*10 if inLS else ''
|
||||
|
||||
printers = self.printers if not inLS else self.printersLS
|
||||
name = self.name if not inLS else self.nameLS
|
||||
titles = ''
|
||||
widths = 0
|
||||
for printer in self.printers:
|
||||
for printer in printers:
|
||||
titles += ('{:^%i}'%printer['width']).format(printer['title']) + ''
|
||||
widths += printer['width']
|
||||
print "{0} {1} {0}".format('='*((widths-1-len(self.name))/2), self.name)
|
||||
print titles
|
||||
print "%s" % '-'*widths
|
||||
print pad + "{0} {1} {0}".format('='*((widths-1-len(name))/2), name)
|
||||
print pad + titles
|
||||
print pad + "%s" % '-'*widths
|
||||
|
||||
def printIter(self):
|
||||
def printIter(self, inLS=False):
|
||||
"""
|
||||
**printIter** is called directly after function evaluations.
|
||||
|
||||
@@ -247,18 +284,17 @@ class Minimize(object):
|
||||
parent.printIter function and call that.
|
||||
|
||||
"""
|
||||
if doPub: pub.sendMessage('Minimize.printIter', minimize=self)
|
||||
if self.parent is not None and hasattr(self.parent, 'printIter'):
|
||||
self.parent.printIter()
|
||||
return
|
||||
if doPub and not inLS: pub.sendMessage('Minimize.printIter', minimize=self)
|
||||
pad = ' '*10 if inLS else ''
|
||||
|
||||
printers = self.printers if not inLS else self.printersLS
|
||||
values = ''
|
||||
for printer in self.printers:
|
||||
for printer in printers:
|
||||
values += ('{:^%i}'%printer['width']).format(printer['format'] % printer['value'](self))
|
||||
print values
|
||||
# print "%3d\t%1.2e\t%1.2e\t%d" % (self._iter, self.f, norm(self.g), self._iterLS)
|
||||
print pad + values
|
||||
# print pad + "%3d\t%1.2e\t%1.2e\t%d" % (self._iter, self.f, norm(self.g), self._iterLS)
|
||||
|
||||
def printDone(self):
|
||||
def printDone(self, inLS=False):
|
||||
"""
|
||||
**printDone** is called at the end of the optimization routine.
|
||||
|
||||
@@ -266,19 +302,21 @@ class Minimize(object):
|
||||
parent.printDone function and call that.
|
||||
|
||||
"""
|
||||
if doPub: pub.sendMessage('Minimize.printDone', minimize=self)
|
||||
print "%s STOP! %s" % ('-'*25,'-'*25)
|
||||
# TODO: put controls on gradient value, min model update, and function value
|
||||
for stopper in self.stoppers:
|
||||
if doPub and not inLS: pub.sendMessage('Minimize.printDone', minimize=self)
|
||||
pad = ' '*10 if inLS else ''
|
||||
stop, done = (' STOP! ', ' DONE! ') if not inLS else ('----------------', ' End Linesearch ')
|
||||
print pad + "%s%s%s" % ('-'*25,stop,'-'*25)
|
||||
|
||||
stoppers = self.stoppers if not inLS else self.stoppersLS
|
||||
for stopper in stoppers:
|
||||
l = stopper['left'](self)
|
||||
r = stopper['right'](self)
|
||||
print stopper['str'] % (l<=r,l,r)
|
||||
print pad + stopper['str'] % (l<=r,l,r)
|
||||
|
||||
print "%s DONE! %s\n" % ('='*25,'='*25)
|
||||
print pad + "%s%s%s" % ('-'*25,done,'-'*25)
|
||||
|
||||
if self.parent is not None and hasattr(self.parent, 'printDone'): self.parent.printDone()
|
||||
|
||||
def stoppingCriteria(self):
|
||||
def stoppingCriteria(self, inLS=False):
|
||||
if self._iter == 0:
|
||||
# Save this for stopping criteria
|
||||
self.f0 = self.f
|
||||
@@ -287,7 +325,9 @@ class Minimize(object):
|
||||
# check stopping rules
|
||||
optimal = []
|
||||
critical = []
|
||||
for stopper in self.stoppers:
|
||||
|
||||
stoppers = self.stoppers if not inLS else self.stoppersLS
|
||||
for stopper in stoppers:
|
||||
l = stopper['left'](self)
|
||||
r = stopper['right'](self)
|
||||
if stopper['stopType'] == 'optimal':
|
||||
@@ -352,6 +392,8 @@ class Minimize(object):
|
||||
p = self.maxStep*p/np.abs(p.max())
|
||||
return p
|
||||
|
||||
nameLS = "Armijo linesearch"
|
||||
|
||||
def modifySearchDirection(self, p):
|
||||
"""
|
||||
**modifySearchDirection** changes the search direction based on some sort of linesearch or trust-region criteria.
|
||||
@@ -371,20 +413,22 @@ class Minimize(object):
|
||||
:return: (xt, passLS)
|
||||
"""
|
||||
# Projected Armijo linesearch
|
||||
t = 1
|
||||
iterLS = 0
|
||||
while iterLS < self.maxIterLS:
|
||||
xt = self.projection(self.xc + t*p)
|
||||
ft = self.evalFunction(xt, return_g=False, return_H=False)
|
||||
descent = np.inner(self.g, xt - self.xc) # this takes into account multiplying by t, but is important for projection.
|
||||
if ft < self.f + t*self.LSreduction*descent:
|
||||
break
|
||||
iterLS += 1
|
||||
t = self.LSshorten*t
|
||||
# TODO: Check if t is tooo small.
|
||||
self._LS_t = 1
|
||||
self._iterLS = 0
|
||||
while self._iterLS < self.maxIterLS:
|
||||
self._LS_xt = self.projection(self.xc + self._LS_t*p)
|
||||
self._LS_ft = self.evalFunction(self._LS_xt, return_g=False, return_H=False)[0]
|
||||
self._LS_descent = np.inner(self.g, self._LS_xt - self.xc) # this takes into account multiplying by t, but is important for projection.
|
||||
if self.stoppingCriteria(inLS=True): break
|
||||
self._iterLS += 1
|
||||
self._LS_t = self.LSshorten*self._LS_t
|
||||
if self.debug:
|
||||
if self._iterLS == 1: self.printInit(inLS=True)
|
||||
self.printIter(inLS=True)
|
||||
|
||||
self._iterLS = iterLS
|
||||
return xt, iterLS < self.maxIterLS
|
||||
if self.debug and self._iterLS > 0: self.printDone(inLS=True)
|
||||
|
||||
return self._LS_xt, self._iterLS < self.maxIterLS
|
||||
|
||||
def modifySearchDirectionBreak(self, p):
|
||||
"""
|
||||
@@ -431,8 +475,8 @@ class Minimize(object):
|
||||
|
||||
|
||||
# store old values
|
||||
self.fOld = self.f
|
||||
self.xOld, self.xc = self.xc, xt
|
||||
self.f_last = self.f
|
||||
self.x_last, self.xc = self.xc, xt
|
||||
self._iter += 1
|
||||
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user