Issue #17 Scaled gradient. moved this to an overridable function. Changed linesearch to modifySearchDirection

This commit is contained in:
Rowan Cockett
2013-11-04 09:43:49 -08:00
parent 07c917dbae
commit 224b0cea9e
+34 -15
View File
@@ -18,7 +18,7 @@ class Minimize(object):
maxIter = 20
maxIterLS = 10
maxStep = np.inf
maxStep = np.inf
LSreduction = 1e-4
LSshorten = 0.5
tolF = 1e-1
@@ -56,14 +56,14 @@ class Minimize(object):
self.printIter()
if self.stoppingCriteria(): break
p = self.findSearchDirection()
if self.maxStep < np.abs(p.max()):
p = self.maxStep*p/np.abs(p.max())
pub.sendMessage('Minimize.searchDirection', minimize=self, p=p)
xt, passLS = self.linesearch(p) ## TODO: should be called modifyStep to be inclusive of trust region stuff etc.
pub.sendMessage('Minimize.linesearch', minimize=self, xt=xt)
p = self.scaleSearchDirection(p)
pub.sendMessage('Minimize.scaleSearchDirection', minimize=self, p=p)
xt, passLS = self.modifySearchDirection(p)
pub.sendMessage('Minimize.modifySearchDirection', minimize=self, xt=xt)
if not passLS:
xt = self.linesearchBreak(p)
return self.xc
xt, caught = self.modifySearchDirectionBreak(p)
if not caught: return self.xc
self.doEndIteration(xt)
pub.sendMessage('Minimize.endIteration', minimize=self, xt=xt)
@@ -129,9 +129,6 @@ class Minimize(object):
print "%d : iter = %3d\t <= maxIter\t = %3d" % (self._STOP[4], self._iter, self.maxIter)
print "%s DONE! %s\n" % ('='*25,'='*25)
def findSearchDirection(self):
return -self.g
def stoppingCriteria(self):
if self._iter == 0:
self.fStop = self.f # Save this for stopping criteria
@@ -147,7 +144,15 @@ class Minimize(object):
def projection(self, p):
return p
def linesearch(self, p):
def findSearchDirection(self):
return -self.g
def scaleSearchDirection(self, p):
if self.maxStep < np.abs(p.max()):
p = self.maxStep*p/np.abs(p.max())
return p
def modifySearchDirection(self, p):
# Armijo linesearch
descent = np.inner(self.g, p)
t = 1
@@ -163,8 +168,24 @@ class Minimize(object):
self._iterLS = iterLS
return xt, iterLS < self.maxIterLS
def linesearchBreak(self, p):
def modifySearchDirectionBreak(self, p):
"""
Code is called if modifySearchDirection fails
to find a descent direction.
The search direction is passed as input and
this function must pass back both a new searchDirection,
and if the searchDirection break has been caught.
By default, no additional work is done, and the
function returns a False indicating the break was not caught.
:param numpy.ndarray p: searchDirection
:rtype: numpy.ndarray,bool
:return: (xt, breakCaught)
"""
print 'The linesearch got broken. Boo.'
return p, False
def doEndIteration(self, xt):
# store old values
@@ -199,9 +220,7 @@ if __name__ == '__main__':
checkDerivative(Rosenbrock, x0, plotIt=False)
def listener1(minimize,p):
plt.plot(p)
plt.show()
print p
print 'hi: ', p
pub.subscribe(listener1, 'Minimize.searchDirection')
xOpt = GaussNewton(maxIter=20,tolF=1e-10,tolX=1e-10,tolG=1e-10).minimize(Rosenbrock,x0)