Merged in counters (pull request #27)

Simple profiling through decorator functions.
This commit is contained in:
rowanc1
2013-11-21 10:41:16 -08:00
5 changed files with 160 additions and 15 deletions
+9 -1
View File
@@ -1,5 +1,5 @@
import numpy as np
from SimPEG.utils import mkvc, sdiag
from SimPEG.utils import mkvc, sdiag, count, timeIt
import scipy.sparse as sp
norm = np.linalg.norm
@@ -37,6 +37,8 @@ class Problem(object):
to (locally) find how model parameters change the data, and optimize!
"""
counter = None
def __init__(self, mesh):
self.mesh = mesh
@@ -83,6 +85,7 @@ class Problem(object):
def dobs(self, value):
self._dobs = value
@count
def dpred(self, m, u=None):
"""
Predicted data.
@@ -94,6 +97,7 @@ class Problem(object):
u = self.field(m)
return self.P*u
@count
def dataResidual(self, m, u=None):
"""
:param numpy.array m: geophysical model
@@ -113,6 +117,7 @@ class Problem(object):
return self.dpred(m, u=u) - self.dobs
@timeIt
def J(self, m, v, u=None):
"""
:param numpy.array m: model
@@ -142,6 +147,7 @@ class Problem(object):
"""
raise NotImplementedError('J is not yet implemented.')
@timeIt
def Jt(self, m, v, u=None):
"""
:param numpy.array m: model
@@ -155,6 +161,7 @@ class Problem(object):
raise NotImplementedError('Jt is not yet implemented.')
@timeIt
def J_approx(self, m, v, u=None):
"""
@@ -169,6 +176,7 @@ class Problem(object):
"""
return self.J(m, v, u)
@timeIt
def Jt_approx(self, m, v, u=None):
"""
:param numpy.array m: model
+8 -3
View File
@@ -1,7 +1,7 @@
import numpy as np
import scipy.sparse as sp
import SimPEG
from SimPEG.utils import sdiag, mkvc, setKwargs, checkStoppers, printStoppers
from SimPEG.utils import sdiag, mkvc, setKwargs, checkStoppers, printStoppers, count, timeIt
from Optimize import Remember
from BetaSchedule import Cooling
@@ -13,6 +13,8 @@ class BaseInversion(object):
debug = False
beta0 = 1e4
counter = None
def __init__(self, prob, reg, opt, **kwargs):
setKwargs(self, **kwargs)
self.prob = prob
@@ -56,6 +58,7 @@ class BaseInversion(object):
def phi_d_target(self, value):
self._phi_d_target = value
@timeIt
def run(self, m0):
self.startup(m0)
while True:
@@ -131,7 +134,7 @@ class BaseInversion(object):
"""
printStoppers(self, self.stoppers)
@timeIt
def evalFunction(self, m, return_g=True, return_H=True):
u = self.prob.field(m)
@@ -162,7 +165,7 @@ class BaseInversion(object):
out += (operator,)
return out if len(out) > 1 else out[0]
@timeIt
def dataObj(self, m, u=None):
"""
:param numpy.array m: geophysical model
@@ -184,6 +187,7 @@ class BaseInversion(object):
R = mkvc(R)
return 0.5*np.vdot(R, R)
@timeIt
def dataObjDeriv(self, m, u=None):
"""
:param numpy.array m: geophysical model
@@ -224,6 +228,7 @@ class BaseInversion(object):
return dmisfit
@timeIt
def dataObj2Deriv(self, m, v, u=None):
"""
:param numpy.array m: geophysical model
+23 -3
View File
@@ -1,6 +1,6 @@
import numpy as np
import matplotlib.pyplot as plt
from SimPEG.utils import mkvc, sdiag, setKwargs, printTitles, printLine, printStoppers, checkStoppers
from SimPEG.utils import mkvc, sdiag, setKwargs, printTitles, printLine, printStoppers, checkStoppers, count, timeIt
norm = np.linalg.norm
import scipy.sparse as sp
from SimPEG import Solver
@@ -106,6 +106,8 @@ class Minimize(object):
debug = False
debugLS = False
counter = None
def __init__(self, **kwargs):
self._id = int(np.random.rand()*1e6) # create a unique identifier to this program to be used in pubsub
self.stoppers = [StoppingCriteria.tolerance_f, StoppingCriteria.moving_x, StoppingCriteria.tolerance_g, StoppingCriteria.norm_g, StoppingCriteria.iteration]
@@ -116,6 +118,7 @@ class Minimize(object):
setKwargs(self, **kwargs)
@timeIt
def minimize(self, evalFunction, x0):
"""
Minimizes the function (evalFunction) starting at the location x0.
@@ -263,6 +266,7 @@ class Minimize(object):
name = self.name if not inLS else self.nameLS
printTitles(self, self.printers if not inLS else self.printersLS, name, pad)
@count
def printIter(self, inLS=False):
"""
**printIter** is called directly after function evaluations.
@@ -289,14 +293,14 @@ class Minimize(object):
stoppers = self.stoppers if not inLS else self.stoppersLS
printStoppers(self, stoppers, pad='', stop=stop, done=done)
@timeIt
def stoppingCriteria(self, inLS=False):
if self._iter == 0:
self.f0 = self.f
self.g0 = self.g
return checkStoppers(self, self.stoppers if not inLS else self.stoppersLS)
@timeIt
def projection(self, p):
"""
projects the search direction.
@@ -309,6 +313,7 @@ class Minimize(object):
"""
return p
@timeIt
def findSearchDirection(self):
"""
**findSearchDirection** should return an approximation of:
@@ -338,6 +343,7 @@ class Minimize(object):
"""
return -self.g
@count
def scaleSearchDirection(self, p):
"""
**scaleSearchDirection** should scale the search direction if appropriate.
@@ -355,6 +361,7 @@ class Minimize(object):
nameLS = "Armijo linesearch"
@timeIt
def modifySearchDirection(self, p):
"""
**modifySearchDirection** changes the search direction based on some sort of linesearch or trust-region criteria.
@@ -391,6 +398,7 @@ class Minimize(object):
return self._LS_xt, self._iterLS < self.maxIterLS
@count
def modifySearchDirectionBreak(self, p):
"""
Code is called if modifySearchDirection fails
@@ -411,6 +419,7 @@ class Minimize(object):
print 'The linesearch got broken. Boo.'
return p, False
@count
def doEndIteration(self, xt):
"""
**doEndIteration** is called at the end of each minimize iteration.
@@ -529,18 +538,22 @@ class ProjectedGradient(Minimize, Remember):
self.aSet_prev = self.activeSet(x0)
@count
def projection(self, x):
"""Make sure we are feasible."""
return np.median(np.c_[self.lower,x,self.upper],axis=1)
@count
def activeSet(self, x):
"""If we are on a bound"""
return np.logical_or(x == self.lower, x == self.upper)
@count
def inactiveSet(self, x):
"""The free variables."""
return np.logical_not(self.activeSet(x))
@count
def bindingSet(self, x):
"""
If we are on a bound and the negative gradient points away from the feasible set.
@@ -552,6 +565,7 @@ class ProjectedGradient(Minimize, Remember):
bind_low = np.logical_and(x == self.upper, self.g <= 0)
return np.logical_or(bind_up, bind_low)
@timeIt
def findSearchDirection(self):
self.aSet_prev = self.activeSet(self.xc)
allBoundsAreActive = sum(self.aSet_prev) == self.xc.size
@@ -592,6 +606,7 @@ class ProjectedGradient(Minimize, Remember):
# aSet_after = self.activeSet(self.xc+p)
return p
@timeIt
def _doEndIteration_ProjectedGradient(self, xt):
aSet = self.activeSet(xt)
bSet = self.bindingSet(xt)
@@ -622,6 +637,8 @@ class ProjectedGradient(Minimize, Remember):
class GaussNewton(Minimize, Remember):
name = 'Gauss Newton'
@timeIt
def findSearchDirection(self):
return Solver(self.H).solve(-self.g)
@@ -632,6 +649,7 @@ class InexactGaussNewton(Minimize, Remember):
maxIterCG = 10
tolCG = 1e-5
@timeIt
def findSearchDirection(self):
# TODO: use BFGS as a preconditioner or gauss sidel of the WtW or solve WtW directly
p, info = sp.linalg.cg(self.H, -self.g, tol=self.tolCG, maxiter=self.maxIterCG)
@@ -640,6 +658,8 @@ class InexactGaussNewton(Minimize, Remember):
class SteepestDescent(Minimize, Remember):
name = 'Steepest Descent'
@timeIt
def findSearchDirection(self):
return -self.g
+9 -8
View File
@@ -1,4 +1,4 @@
from SimPEG.utils import sdiag
from SimPEG.utils import sdiag, count, timeIt
import numpy as np
class Regularization(object):
@@ -40,21 +40,20 @@ class Regularization(object):
self._Wz = sdiag(a)*self.mesh.cellGradz
return self._Wz
alpha_s = 1e-6
alpha_x = 1
alpha_y = 1
alpha_z = 1
counter = None
def __init__(self, mesh):
self.mesh = mesh
self._Wx = None
self._Wy = None
self._Wz = None
self.alpha_s = 1e-6
self.alpha_x = 1
self.alpha_y = 1
self.alpha_z = 1
def pnorm(self, r):
return 0.5*r.dot(r)
@timeIt
def modelObj(self, m):
mresid = m - self.mref
@@ -69,6 +68,7 @@ class Regularization(object):
return mobj
@timeIt
def modelObjDeriv(self, m):
"""
@@ -104,6 +104,7 @@ class Regularization(object):
return mobjDeriv
@timeIt
def modelObj2Deriv(self, m):
mresid = m - self.mref
+111
View File
@@ -60,3 +60,114 @@ def printStoppers(obj, stoppers, pad='', stop='STOP!', done='DONE!'):
r = stopper['right'](obj)
print pad + stopper['str'] % (l<=r,l,r)
print pad + "%s%s%s" % ('-'*25,done,'-'*25)
import time
import numpy as np
class Counter(object):
"""
Counter allows anything that calls it to record iterations and
timings in a simple way.
Also has plotting functions that allow quick recalls of data.
If you want to use this, import *count* or *timeIt* and use them as decorators on class methods.
.. ::
class MyClass(object):
def __init__(self, url):
self.counter = Counter()
@count
def MyMethod(self):
pass
@timeIt
def MySecondMethod(self):
pass
c = MyClass('blah')
for i in range(100): c.MyMethod()
for i in range(300): c.MySecondMethod()
c.counter.summary()
"""
def __init__(self):
self._countList = {}
self._timeList = {}
def count(self, prop):
"""
Increases the count of the property.
"""
assert type(prop) is str, 'The property must be a string.'
if prop not in self._countList:
self._countList[prop] = 0
self._countList[prop] += 1
def countTic(self, prop):
"""
Times a property call, this is the init call.
"""
assert type(prop) is str, 'The property must be a string.'
if prop not in self._timeList:
self._timeList[prop] = []
self._timeList[prop].append(-time.time())
def countToc(self, prop):
"""
Times a property call, this is the end call.
"""
assert type(prop) is str, 'The property must be a string.'
assert prop in self._timeList, 'The property must already be in the dictionary.'
self._timeList[prop][-1] += time.time()
def summary(self):
"""
Provides a text summary of the current counters and timers.
"""
print 'Counters:'
for prop in sorted(self._countList):
print " {0:<40}: {1:8d}".format(prop,self._countList[prop])
print '\nTimes:'+' '*40+'mean sum'
for prop in sorted(self._timeList):
l = len(self._timeList[prop])
a = np.array(self._timeList[prop])
print " {0:<40}: {1:4.2e}, {2:4.2e}, {3:4d}x".format(prop,a.mean(),a.sum(),l)
def count(f):
def wrapper(self,*args,**kwargs):
counter = getattr(self,'counter',None)
if type(counter) is Counter: counter.count(self.__class__.__name__+'.'+f.__name__)
out = f(self,*args,**kwargs)
return out
return wrapper
def timeIt(f):
def wrapper(self,*args,**kwargs):
counter = getattr(self,'counter',None)
if type(counter) is Counter: counter.countTic(self.__class__.__name__+'.'+f.__name__)
out = f(self,*args,**kwargs)
if type(counter) is Counter: counter.countToc(self.__class__.__name__+'.'+f.__name__)
return out
return wrapper
if __name__ == '__main__':
class MyClass(object):
def __init__(self, url):
self.counter = Counter()
@count
def MyMethod(self):
pass
@timeIt
def MySecondMethod(self):
pass
c = MyClass('blah')
for i in range(100): c.MyMethod()
for i in range(300): c.MySecondMethod()
c.counter.summary()