mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-04 19:08:03 +08:00
Issue #17 Scaled gradient. moved this to an overridable function. Changed linesearch to modifySearchDirection
This commit is contained in:
+34
-15
@@ -18,7 +18,7 @@ class Minimize(object):
|
||||
|
||||
maxIter = 20
|
||||
maxIterLS = 10
|
||||
maxStep = np.inf
|
||||
maxStep = np.inf
|
||||
LSreduction = 1e-4
|
||||
LSshorten = 0.5
|
||||
tolF = 1e-1
|
||||
@@ -56,14 +56,14 @@ class Minimize(object):
|
||||
self.printIter()
|
||||
if self.stoppingCriteria(): break
|
||||
p = self.findSearchDirection()
|
||||
if self.maxStep < np.abs(p.max()):
|
||||
p = self.maxStep*p/np.abs(p.max())
|
||||
pub.sendMessage('Minimize.searchDirection', minimize=self, p=p)
|
||||
xt, passLS = self.linesearch(p) ## TODO: should be called modifyStep to be inclusive of trust region stuff etc.
|
||||
pub.sendMessage('Minimize.linesearch', minimize=self, xt=xt)
|
||||
p = self.scaleSearchDirection(p)
|
||||
pub.sendMessage('Minimize.scaleSearchDirection', minimize=self, p=p)
|
||||
xt, passLS = self.modifySearchDirection(p)
|
||||
pub.sendMessage('Minimize.modifySearchDirection', minimize=self, xt=xt)
|
||||
if not passLS:
|
||||
xt = self.linesearchBreak(p)
|
||||
return self.xc
|
||||
xt, caught = self.modifySearchDirectionBreak(p)
|
||||
if not caught: return self.xc
|
||||
self.doEndIteration(xt)
|
||||
pub.sendMessage('Minimize.endIteration', minimize=self, xt=xt)
|
||||
|
||||
@@ -129,9 +129,6 @@ class Minimize(object):
|
||||
print "%d : iter = %3d\t <= maxIter\t = %3d" % (self._STOP[4], self._iter, self.maxIter)
|
||||
print "%s DONE! %s\n" % ('='*25,'='*25)
|
||||
|
||||
def findSearchDirection(self):
|
||||
return -self.g
|
||||
|
||||
def stoppingCriteria(self):
|
||||
if self._iter == 0:
|
||||
self.fStop = self.f # Save this for stopping criteria
|
||||
@@ -147,7 +144,15 @@ class Minimize(object):
|
||||
def projection(self, p):
|
||||
return p
|
||||
|
||||
def linesearch(self, p):
|
||||
def findSearchDirection(self):
|
||||
return -self.g
|
||||
|
||||
def scaleSearchDirection(self, p):
|
||||
if self.maxStep < np.abs(p.max()):
|
||||
p = self.maxStep*p/np.abs(p.max())
|
||||
return p
|
||||
|
||||
def modifySearchDirection(self, p):
|
||||
# Armijo linesearch
|
||||
descent = np.inner(self.g, p)
|
||||
t = 1
|
||||
@@ -163,8 +168,24 @@ class Minimize(object):
|
||||
self._iterLS = iterLS
|
||||
return xt, iterLS < self.maxIterLS
|
||||
|
||||
def linesearchBreak(self, p):
|
||||
def modifySearchDirectionBreak(self, p):
|
||||
"""
|
||||
Code is called if modifySearchDirection fails
|
||||
to find a descent direction.
|
||||
|
||||
The search direction is passed as input and
|
||||
this function must pass back both a new searchDirection,
|
||||
and if the searchDirection break has been caught.
|
||||
|
||||
By default, no additional work is done, and the
|
||||
function returns a False indicating the break was not caught.
|
||||
|
||||
:param numpy.ndarray p: searchDirection
|
||||
:rtype: numpy.ndarray,bool
|
||||
:return: (xt, breakCaught)
|
||||
"""
|
||||
print 'The linesearch got broken. Boo.'
|
||||
return p, False
|
||||
|
||||
def doEndIteration(self, xt):
|
||||
# store old values
|
||||
@@ -199,9 +220,7 @@ if __name__ == '__main__':
|
||||
checkDerivative(Rosenbrock, x0, plotIt=False)
|
||||
|
||||
def listener1(minimize,p):
|
||||
plt.plot(p)
|
||||
plt.show()
|
||||
print p
|
||||
print 'hi: ', p
|
||||
pub.subscribe(listener1, 'Minimize.searchDirection')
|
||||
|
||||
xOpt = GaussNewton(maxIter=20,tolF=1e-10,tolX=1e-10,tolG=1e-10).minimize(Rosenbrock,x0)
|
||||
|
||||
Reference in New Issue
Block a user