diff --git a/SimPEG/inverse/Optimize.py b/SimPEG/inverse/Optimize.py index 01a6f453..d4c46c69 100644 --- a/SimPEG/inverse/Optimize.py +++ b/SimPEG/inverse/Optimize.py @@ -18,7 +18,7 @@ class Minimize(object): maxIter = 20 maxIterLS = 10 - maxStep = np.inf + maxStep = np.inf LSreduction = 1e-4 LSshorten = 0.5 tolF = 1e-1 @@ -56,14 +56,14 @@ class Minimize(object): self.printIter() if self.stoppingCriteria(): break p = self.findSearchDirection() - if self.maxStep < np.abs(p.max()): - p = self.maxStep*p/np.abs(p.max()) pub.sendMessage('Minimize.searchDirection', minimize=self, p=p) - xt, passLS = self.linesearch(p) ## TODO: should be called modifyStep to be inclusive of trust region stuff etc. - pub.sendMessage('Minimize.linesearch', minimize=self, xt=xt) + p = self.scaleSearchDirection(p) + pub.sendMessage('Minimize.scaleSearchDirection', minimize=self, p=p) + xt, passLS = self.modifySearchDirection(p) + pub.sendMessage('Minimize.modifySearchDirection', minimize=self, xt=xt) if not passLS: - xt = self.linesearchBreak(p) - return self.xc + xt, caught = self.modifySearchDirectionBreak(p) + if not caught: return self.xc self.doEndIteration(xt) pub.sendMessage('Minimize.endIteration', minimize=self, xt=xt) @@ -129,9 +129,6 @@ class Minimize(object): print "%d : iter = %3d\t <= maxIter\t = %3d" % (self._STOP[4], self._iter, self.maxIter) print "%s DONE! %s\n" % ('='*25,'='*25) - def findSearchDirection(self): - return -self.g - def stoppingCriteria(self): if self._iter == 0: self.fStop = self.f # Save this for stopping criteria @@ -147,7 +144,15 @@ class Minimize(object): def projection(self, p): return p - def linesearch(self, p): + def findSearchDirection(self): + return -self.g + + def scaleSearchDirection(self, p): + if self.maxStep < np.abs(p.max()): + p = self.maxStep*p/np.abs(p.max()) + return p + + def modifySearchDirection(self, p): # Armijo linesearch descent = np.inner(self.g, p) t = 1 @@ -163,8 +168,24 @@ class Minimize(object): self._iterLS = iterLS return xt, iterLS < self.maxIterLS - def linesearchBreak(self, p): + def modifySearchDirectionBreak(self, p): + """ + Code is called if modifySearchDirection fails + to find a descent direction. + + The search direction is passed as input and + this function must pass back both a new searchDirection, + and if the searchDirection break has been caught. + + By default, no additional work is done, and the + function returns a False indicating the break was not caught. + + :param numpy.ndarray p: searchDirection + :rtype: numpy.ndarray,bool + :return: (xt, breakCaught) + """ print 'The linesearch got broken. Boo.' + return p, False def doEndIteration(self, xt): # store old values @@ -199,9 +220,7 @@ if __name__ == '__main__': checkDerivative(Rosenbrock, x0, plotIt=False) def listener1(minimize,p): - plt.plot(p) - plt.show() - print p + print 'hi: ', p pub.subscribe(listener1, 'Minimize.searchDirection') xOpt = GaussNewton(maxIter=20,tolF=1e-10,tolX=1e-10,tolG=1e-10).minimize(Rosenbrock,x0)