mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-30 04:19:45 +08:00
Merged in counters (pull request #27)
Simple profiling through decorator functions.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import numpy as np
|
||||
from SimPEG.utils import mkvc, sdiag
|
||||
from SimPEG.utils import mkvc, sdiag, count, timeIt
|
||||
import scipy.sparse as sp
|
||||
norm = np.linalg.norm
|
||||
|
||||
@@ -37,6 +37,8 @@ class Problem(object):
|
||||
to (locally) find how model parameters change the data, and optimize!
|
||||
"""
|
||||
|
||||
counter = None
|
||||
|
||||
def __init__(self, mesh):
|
||||
self.mesh = mesh
|
||||
|
||||
@@ -83,6 +85,7 @@ class Problem(object):
|
||||
def dobs(self, value):
|
||||
self._dobs = value
|
||||
|
||||
@count
|
||||
def dpred(self, m, u=None):
|
||||
"""
|
||||
Predicted data.
|
||||
@@ -94,6 +97,7 @@ class Problem(object):
|
||||
u = self.field(m)
|
||||
return self.P*u
|
||||
|
||||
@count
|
||||
def dataResidual(self, m, u=None):
|
||||
"""
|
||||
:param numpy.array m: geophysical model
|
||||
@@ -113,6 +117,7 @@ class Problem(object):
|
||||
|
||||
return self.dpred(m, u=u) - self.dobs
|
||||
|
||||
@timeIt
|
||||
def J(self, m, v, u=None):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
@@ -142,6 +147,7 @@ class Problem(object):
|
||||
"""
|
||||
raise NotImplementedError('J is not yet implemented.')
|
||||
|
||||
@timeIt
|
||||
def Jt(self, m, v, u=None):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
@@ -155,6 +161,7 @@ class Problem(object):
|
||||
raise NotImplementedError('Jt is not yet implemented.')
|
||||
|
||||
|
||||
@timeIt
|
||||
def J_approx(self, m, v, u=None):
|
||||
"""
|
||||
|
||||
@@ -169,6 +176,7 @@ class Problem(object):
|
||||
"""
|
||||
return self.J(m, v, u)
|
||||
|
||||
@timeIt
|
||||
def Jt_approx(self, m, v, u=None):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
import SimPEG
|
||||
from SimPEG.utils import sdiag, mkvc, setKwargs, checkStoppers, printStoppers
|
||||
from SimPEG.utils import sdiag, mkvc, setKwargs, checkStoppers, printStoppers, count, timeIt
|
||||
from Optimize import Remember
|
||||
from BetaSchedule import Cooling
|
||||
|
||||
@@ -13,6 +13,8 @@ class BaseInversion(object):
|
||||
debug = False
|
||||
beta0 = 1e4
|
||||
|
||||
counter = None
|
||||
|
||||
def __init__(self, prob, reg, opt, **kwargs):
|
||||
setKwargs(self, **kwargs)
|
||||
self.prob = prob
|
||||
@@ -56,6 +58,7 @@ class BaseInversion(object):
|
||||
def phi_d_target(self, value):
|
||||
self._phi_d_target = value
|
||||
|
||||
@timeIt
|
||||
def run(self, m0):
|
||||
self.startup(m0)
|
||||
while True:
|
||||
@@ -131,7 +134,7 @@ class BaseInversion(object):
|
||||
"""
|
||||
printStoppers(self, self.stoppers)
|
||||
|
||||
|
||||
@timeIt
|
||||
def evalFunction(self, m, return_g=True, return_H=True):
|
||||
|
||||
u = self.prob.field(m)
|
||||
@@ -162,7 +165,7 @@ class BaseInversion(object):
|
||||
out += (operator,)
|
||||
return out if len(out) > 1 else out[0]
|
||||
|
||||
|
||||
@timeIt
|
||||
def dataObj(self, m, u=None):
|
||||
"""
|
||||
:param numpy.array m: geophysical model
|
||||
@@ -184,6 +187,7 @@ class BaseInversion(object):
|
||||
R = mkvc(R)
|
||||
return 0.5*np.vdot(R, R)
|
||||
|
||||
@timeIt
|
||||
def dataObjDeriv(self, m, u=None):
|
||||
"""
|
||||
:param numpy.array m: geophysical model
|
||||
@@ -224,6 +228,7 @@ class BaseInversion(object):
|
||||
|
||||
return dmisfit
|
||||
|
||||
@timeIt
|
||||
def dataObj2Deriv(self, m, v, u=None):
|
||||
"""
|
||||
:param numpy.array m: geophysical model
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from SimPEG.utils import mkvc, sdiag, setKwargs, printTitles, printLine, printStoppers, checkStoppers
|
||||
from SimPEG.utils import mkvc, sdiag, setKwargs, printTitles, printLine, printStoppers, checkStoppers, count, timeIt
|
||||
norm = np.linalg.norm
|
||||
import scipy.sparse as sp
|
||||
from SimPEG import Solver
|
||||
@@ -106,6 +106,8 @@ class Minimize(object):
|
||||
debug = False
|
||||
debugLS = False
|
||||
|
||||
counter = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._id = int(np.random.rand()*1e6) # create a unique identifier to this program to be used in pubsub
|
||||
self.stoppers = [StoppingCriteria.tolerance_f, StoppingCriteria.moving_x, StoppingCriteria.tolerance_g, StoppingCriteria.norm_g, StoppingCriteria.iteration]
|
||||
@@ -116,6 +118,7 @@ class Minimize(object):
|
||||
|
||||
setKwargs(self, **kwargs)
|
||||
|
||||
@timeIt
|
||||
def minimize(self, evalFunction, x0):
|
||||
"""
|
||||
Minimizes the function (evalFunction) starting at the location x0.
|
||||
@@ -263,6 +266,7 @@ class Minimize(object):
|
||||
name = self.name if not inLS else self.nameLS
|
||||
printTitles(self, self.printers if not inLS else self.printersLS, name, pad)
|
||||
|
||||
@count
|
||||
def printIter(self, inLS=False):
|
||||
"""
|
||||
**printIter** is called directly after function evaluations.
|
||||
@@ -289,14 +293,14 @@ class Minimize(object):
|
||||
stoppers = self.stoppers if not inLS else self.stoppersLS
|
||||
printStoppers(self, stoppers, pad='', stop=stop, done=done)
|
||||
|
||||
|
||||
@timeIt
|
||||
def stoppingCriteria(self, inLS=False):
|
||||
if self._iter == 0:
|
||||
self.f0 = self.f
|
||||
self.g0 = self.g
|
||||
return checkStoppers(self, self.stoppers if not inLS else self.stoppersLS)
|
||||
|
||||
|
||||
@timeIt
|
||||
def projection(self, p):
|
||||
"""
|
||||
projects the search direction.
|
||||
@@ -309,6 +313,7 @@ class Minimize(object):
|
||||
"""
|
||||
return p
|
||||
|
||||
@timeIt
|
||||
def findSearchDirection(self):
|
||||
"""
|
||||
**findSearchDirection** should return an approximation of:
|
||||
@@ -338,6 +343,7 @@ class Minimize(object):
|
||||
"""
|
||||
return -self.g
|
||||
|
||||
@count
|
||||
def scaleSearchDirection(self, p):
|
||||
"""
|
||||
**scaleSearchDirection** should scale the search direction if appropriate.
|
||||
@@ -355,6 +361,7 @@ class Minimize(object):
|
||||
|
||||
nameLS = "Armijo linesearch"
|
||||
|
||||
@timeIt
|
||||
def modifySearchDirection(self, p):
|
||||
"""
|
||||
**modifySearchDirection** changes the search direction based on some sort of linesearch or trust-region criteria.
|
||||
@@ -391,6 +398,7 @@ class Minimize(object):
|
||||
|
||||
return self._LS_xt, self._iterLS < self.maxIterLS
|
||||
|
||||
@count
|
||||
def modifySearchDirectionBreak(self, p):
|
||||
"""
|
||||
Code is called if modifySearchDirection fails
|
||||
@@ -411,6 +419,7 @@ class Minimize(object):
|
||||
print 'The linesearch got broken. Boo.'
|
||||
return p, False
|
||||
|
||||
@count
|
||||
def doEndIteration(self, xt):
|
||||
"""
|
||||
**doEndIteration** is called at the end of each minimize iteration.
|
||||
@@ -529,18 +538,22 @@ class ProjectedGradient(Minimize, Remember):
|
||||
|
||||
self.aSet_prev = self.activeSet(x0)
|
||||
|
||||
@count
|
||||
def projection(self, x):
|
||||
"""Make sure we are feasible."""
|
||||
return np.median(np.c_[self.lower,x,self.upper],axis=1)
|
||||
|
||||
@count
|
||||
def activeSet(self, x):
|
||||
"""If we are on a bound"""
|
||||
return np.logical_or(x == self.lower, x == self.upper)
|
||||
|
||||
@count
|
||||
def inactiveSet(self, x):
|
||||
"""The free variables."""
|
||||
return np.logical_not(self.activeSet(x))
|
||||
|
||||
@count
|
||||
def bindingSet(self, x):
|
||||
"""
|
||||
If we are on a bound and the negative gradient points away from the feasible set.
|
||||
@@ -552,6 +565,7 @@ class ProjectedGradient(Minimize, Remember):
|
||||
bind_low = np.logical_and(x == self.upper, self.g <= 0)
|
||||
return np.logical_or(bind_up, bind_low)
|
||||
|
||||
@timeIt
|
||||
def findSearchDirection(self):
|
||||
self.aSet_prev = self.activeSet(self.xc)
|
||||
allBoundsAreActive = sum(self.aSet_prev) == self.xc.size
|
||||
@@ -592,6 +606,7 @@ class ProjectedGradient(Minimize, Remember):
|
||||
# aSet_after = self.activeSet(self.xc+p)
|
||||
return p
|
||||
|
||||
@timeIt
|
||||
def _doEndIteration_ProjectedGradient(self, xt):
|
||||
aSet = self.activeSet(xt)
|
||||
bSet = self.bindingSet(xt)
|
||||
@@ -622,6 +637,8 @@ class ProjectedGradient(Minimize, Remember):
|
||||
|
||||
class GaussNewton(Minimize, Remember):
|
||||
name = 'Gauss Newton'
|
||||
|
||||
@timeIt
|
||||
def findSearchDirection(self):
|
||||
return Solver(self.H).solve(-self.g)
|
||||
|
||||
@@ -632,6 +649,7 @@ class InexactGaussNewton(Minimize, Remember):
|
||||
maxIterCG = 10
|
||||
tolCG = 1e-5
|
||||
|
||||
@timeIt
|
||||
def findSearchDirection(self):
|
||||
# TODO: use BFGS as a preconditioner or gauss sidel of the WtW or solve WtW directly
|
||||
p, info = sp.linalg.cg(self.H, -self.g, tol=self.tolCG, maxiter=self.maxIterCG)
|
||||
@@ -640,6 +658,8 @@ class InexactGaussNewton(Minimize, Remember):
|
||||
|
||||
class SteepestDescent(Minimize, Remember):
|
||||
name = 'Steepest Descent'
|
||||
|
||||
@timeIt
|
||||
def findSearchDirection(self):
|
||||
return -self.g
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from SimPEG.utils import sdiag
|
||||
from SimPEG.utils import sdiag, count, timeIt
|
||||
import numpy as np
|
||||
|
||||
class Regularization(object):
|
||||
@@ -40,21 +40,20 @@ class Regularization(object):
|
||||
self._Wz = sdiag(a)*self.mesh.cellGradz
|
||||
return self._Wz
|
||||
|
||||
alpha_s = 1e-6
|
||||
alpha_x = 1
|
||||
alpha_y = 1
|
||||
alpha_z = 1
|
||||
|
||||
counter = None
|
||||
|
||||
def __init__(self, mesh):
|
||||
self.mesh = mesh
|
||||
self._Wx = None
|
||||
self._Wy = None
|
||||
self._Wz = None
|
||||
self.alpha_s = 1e-6
|
||||
self.alpha_x = 1
|
||||
self.alpha_y = 1
|
||||
self.alpha_z = 1
|
||||
|
||||
def pnorm(self, r):
|
||||
return 0.5*r.dot(r)
|
||||
|
||||
@timeIt
|
||||
def modelObj(self, m):
|
||||
mresid = m - self.mref
|
||||
|
||||
@@ -69,6 +68,7 @@ class Regularization(object):
|
||||
|
||||
return mobj
|
||||
|
||||
@timeIt
|
||||
def modelObjDeriv(self, m):
|
||||
"""
|
||||
|
||||
@@ -104,6 +104,7 @@ class Regularization(object):
|
||||
return mobjDeriv
|
||||
|
||||
|
||||
@timeIt
|
||||
def modelObj2Deriv(self, m):
|
||||
mresid = m - self.mref
|
||||
|
||||
|
||||
@@ -60,3 +60,114 @@ def printStoppers(obj, stoppers, pad='', stop='STOP!', done='DONE!'):
|
||||
r = stopper['right'](obj)
|
||||
print pad + stopper['str'] % (l<=r,l,r)
|
||||
print pad + "%s%s%s" % ('-'*25,done,'-'*25)
|
||||
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Counter(object):
|
||||
"""
|
||||
Counter allows anything that calls it to record iterations and
|
||||
timings in a simple way.
|
||||
|
||||
Also has plotting functions that allow quick recalls of data.
|
||||
|
||||
If you want to use this, import *count* or *timeIt* and use them as decorators on class methods.
|
||||
|
||||
.. ::
|
||||
|
||||
class MyClass(object):
|
||||
def __init__(self, url):
|
||||
self.counter = Counter()
|
||||
|
||||
@count
|
||||
def MyMethod(self):
|
||||
pass
|
||||
|
||||
@timeIt
|
||||
def MySecondMethod(self):
|
||||
pass
|
||||
|
||||
c = MyClass('blah')
|
||||
for i in range(100): c.MyMethod()
|
||||
for i in range(300): c.MySecondMethod()
|
||||
c.counter.summary()
|
||||
|
||||
"""
|
||||
def __init__(self):
|
||||
self._countList = {}
|
||||
self._timeList = {}
|
||||
|
||||
def count(self, prop):
|
||||
"""
|
||||
Increases the count of the property.
|
||||
"""
|
||||
assert type(prop) is str, 'The property must be a string.'
|
||||
if prop not in self._countList:
|
||||
self._countList[prop] = 0
|
||||
self._countList[prop] += 1
|
||||
|
||||
def countTic(self, prop):
|
||||
"""
|
||||
Times a property call, this is the init call.
|
||||
"""
|
||||
assert type(prop) is str, 'The property must be a string.'
|
||||
if prop not in self._timeList:
|
||||
self._timeList[prop] = []
|
||||
self._timeList[prop].append(-time.time())
|
||||
|
||||
def countToc(self, prop):
|
||||
"""
|
||||
Times a property call, this is the end call.
|
||||
"""
|
||||
assert type(prop) is str, 'The property must be a string.'
|
||||
assert prop in self._timeList, 'The property must already be in the dictionary.'
|
||||
self._timeList[prop][-1] += time.time()
|
||||
|
||||
def summary(self):
|
||||
"""
|
||||
Provides a text summary of the current counters and timers.
|
||||
"""
|
||||
print 'Counters:'
|
||||
for prop in sorted(self._countList):
|
||||
print " {0:<40}: {1:8d}".format(prop,self._countList[prop])
|
||||
print '\nTimes:'+' '*40+'mean sum'
|
||||
for prop in sorted(self._timeList):
|
||||
l = len(self._timeList[prop])
|
||||
a = np.array(self._timeList[prop])
|
||||
print " {0:<40}: {1:4.2e}, {2:4.2e}, {3:4d}x".format(prop,a.mean(),a.sum(),l)
|
||||
|
||||
def count(f):
|
||||
def wrapper(self,*args,**kwargs):
|
||||
counter = getattr(self,'counter',None)
|
||||
if type(counter) is Counter: counter.count(self.__class__.__name__+'.'+f.__name__)
|
||||
out = f(self,*args,**kwargs)
|
||||
return out
|
||||
return wrapper
|
||||
|
||||
def timeIt(f):
|
||||
def wrapper(self,*args,**kwargs):
|
||||
counter = getattr(self,'counter',None)
|
||||
if type(counter) is Counter: counter.countTic(self.__class__.__name__+'.'+f.__name__)
|
||||
out = f(self,*args,**kwargs)
|
||||
if type(counter) is Counter: counter.countToc(self.__class__.__name__+'.'+f.__name__)
|
||||
return out
|
||||
return wrapper
|
||||
if __name__ == '__main__':
|
||||
class MyClass(object):
|
||||
def __init__(self, url):
|
||||
self.counter = Counter()
|
||||
|
||||
@count
|
||||
def MyMethod(self):
|
||||
pass
|
||||
|
||||
@timeIt
|
||||
def MySecondMethod(self):
|
||||
pass
|
||||
|
||||
c = MyClass('blah')
|
||||
for i in range(100): c.MyMethod()
|
||||
for i in range(300): c.MySecondMethod()
|
||||
c.counter.summary()
|
||||
|
||||
Reference in New Issue
Block a user