mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-27 22:23:39 +08:00
Merge branch 'origin/master'
Conflicts: SimPEG/__init__.py
This commit is contained in:
@@ -1,75 +0,0 @@
|
||||
import numpy as np
|
||||
import scipy.sparse.linalg as linalg
|
||||
|
||||
|
||||
class Solver(object):
|
||||
"""docstring for Solver"""
|
||||
def __init__(self, A, doDirect=True, flag=None, options={}):
|
||||
|
||||
assert type(doDirect) is bool, 'doDirect must be a boolean'
|
||||
assert flag in [None, 'L', 'U', 'D'], "flag must be set to None, 'L', 'U', or 'D'"
|
||||
|
||||
self.A = A
|
||||
|
||||
self.dsolve = None
|
||||
self.doDirect = doDirect
|
||||
self.flag = flag
|
||||
self.options = options
|
||||
|
||||
def solve(self, b):
|
||||
if self.flag is None and self.doDirect:
|
||||
return self.solveDirect(b, **self.options)
|
||||
elif self.flag is None and not self.doDirect:
|
||||
return self.solveIter(b, **self.options)
|
||||
elif self.flag == 'U':
|
||||
return self.solveBackward(b)
|
||||
elif self.flag == 'L':
|
||||
return self.solveForward(b)
|
||||
elif self.flag == 'D':
|
||||
return self.solveDiagonal(b)
|
||||
else:
|
||||
raise Exception('Unknown flag.')
|
||||
pass
|
||||
|
||||
def clean(self):
|
||||
"""Cleans up the memory"""
|
||||
del self.dsolve
|
||||
self.dsolve = None
|
||||
|
||||
def solveDirect(self, b, backend='scipy'):
|
||||
assert np.shape(self.A)[1] == np.shape(b)[0], 'Dimension mismatch'
|
||||
|
||||
if self.dsolve is None:
|
||||
self.A = self.A.tocsc() # for efficiency
|
||||
self.dsolve = linalg.factorized(self.A)
|
||||
|
||||
if len(b.shape) == 1 or b.shape[1] == 1:
|
||||
# Just one RHS
|
||||
return self.dsolve(b)
|
||||
|
||||
# Multiple RHSs
|
||||
X = np.empty_like(b)
|
||||
for i in range(b.shape[1]):
|
||||
X[:,i] = self.dsolve(b[:,i])
|
||||
|
||||
return X
|
||||
|
||||
def solveIter(self, b, M=None, iterSolver='CG'):
|
||||
pass
|
||||
|
||||
def solveBackward(self, b):
|
||||
pass
|
||||
|
||||
def solveForward(self, b):
|
||||
pass
|
||||
|
||||
def solveDiagonal(self, b):
|
||||
diagA = self.A.diagonal()
|
||||
if len(b.shape) == 1 or b.shape[1] == 1:
|
||||
# Just one RHS
|
||||
return b/diagA
|
||||
# Multiple RHSs
|
||||
X = np.empty_like(b)
|
||||
for i in range(b.shape[1]):
|
||||
X[:,i] = b[:,i]/diagA
|
||||
return X
|
||||
+7
-1
@@ -1,5 +1,11 @@
|
||||
import mesh
|
||||
import utils
|
||||
from utils import Solver
|
||||
import mesh
|
||||
import inverse
|
||||
<<<<<<< HEAD
|
||||
from Solver import Solver
|
||||
import visulize
|
||||
=======
|
||||
import forward
|
||||
import regularization
|
||||
>>>>>>> refs/remotes/origin/master
|
||||
|
||||
@@ -0,0 +1,249 @@
|
||||
from SimPEG.mesh import TensorMesh
|
||||
from SimPEG.forward import Problem, SyntheticProblem, ModelTransforms
|
||||
from SimPEG.tests import checkDerivative
|
||||
from SimPEG.utils import ModelBuilder, sdiag, mkvc
|
||||
from SimPEG import Solver
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
import scipy.sparse.linalg as linalg
|
||||
|
||||
|
||||
class DCProblem(ModelTransforms.LogModel, Problem):
|
||||
"""
|
||||
**DCProblem**
|
||||
|
||||
Geophysical DC resistivity problem.
|
||||
|
||||
"""
|
||||
def __init__(self, mesh):
|
||||
super(DCProblem, self).__init__(mesh)
|
||||
self.mesh.setCellGradBC('neumann')
|
||||
|
||||
def reshapeFields(self, u):
|
||||
if len(u.shape) == 1:
|
||||
u = u.reshape([-1, self.RHS.shape[1]], order='F')
|
||||
return u
|
||||
|
||||
def createMatrix(self, m):
|
||||
"""
|
||||
Makes the matrix A(m) for the DC resistivity problem.
|
||||
|
||||
:param numpy.array m: model
|
||||
:rtype: scipy.csc_matrix
|
||||
:return: A(m)
|
||||
|
||||
.. math::
|
||||
c(m,u) = A(m)u - q = G\\text{sdiag}(M(mT(m)))Du - q = 0
|
||||
|
||||
Where M() is the mass matrix and mT is the model transform.
|
||||
"""
|
||||
D = self.mesh.faceDiv
|
||||
G = self.mesh.cellGrad
|
||||
sigma = self.modelTransform(m)
|
||||
Msig = self.mesh.getFaceMass(sigma)
|
||||
A = D*Msig*G
|
||||
return A.tocsc()
|
||||
|
||||
def dpred(self, m, u=None):
|
||||
"""
|
||||
Predicted data.
|
||||
|
||||
.. math::
|
||||
d_\\text{pred} = Pu(m)
|
||||
"""
|
||||
if u is None:
|
||||
u = self.field(m)
|
||||
|
||||
u = self.reshapeFields(u)
|
||||
|
||||
return mkvc(self.P*u)
|
||||
|
||||
def field(self, m):
|
||||
A = self.createMatrix(m)
|
||||
solve = Solver(A)
|
||||
phi = solve.solve(self.RHS)
|
||||
return mkvc(phi)
|
||||
|
||||
def J(self, m, v, u=None):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
:param numpy.array v: vector to multiply
|
||||
:param numpy.array u: fields
|
||||
:rtype: numpy.array
|
||||
:return: Jv
|
||||
|
||||
.. math::
|
||||
c(m,u) = A(m)u - q = G\\text{sdiag}(M(mT(m)))Du - q = 0
|
||||
|
||||
\\nabla_u (A(m)u - q) = A(m)
|
||||
|
||||
\\nabla_m (A(m)u - q) = G\\text{sdiag}(Du)\\nabla_m(M(mT(m)))
|
||||
|
||||
Where M() is the mass matrix and mT is the model transform.
|
||||
|
||||
.. math::
|
||||
J = - P \left( \\nabla_u c(m, u) \\right)^{-1} \\nabla_m c(m, u)
|
||||
|
||||
J(v) = - P ( A(m)^{-1} ( G\\text{sdiag}(Du)\\nabla_m(M(mT(m))) v ) )
|
||||
"""
|
||||
if u is None:
|
||||
u = self.field(m)
|
||||
|
||||
u = self.reshapeFields(u)
|
||||
|
||||
P = self.P
|
||||
D = self.mesh.faceDiv
|
||||
G = self.mesh.cellGrad
|
||||
A = self.createMatrix(m)
|
||||
Av_dm = self.mesh.getFaceMassDeriv()
|
||||
mT_dm = self.modelTransformDeriv(m)
|
||||
|
||||
dCdu = A
|
||||
|
||||
dCdm = np.empty_like(u)
|
||||
for i, ui in enumerate(u.T): # loop over each column
|
||||
dCdm[:, i] = D * ( sdiag( G * ui ) * ( Av_dm * ( mT_dm * v ) ) )
|
||||
|
||||
solve = Solver(dCdu)
|
||||
Jv = - P * solve.solve(dCdm)
|
||||
return mkvc(Jv)
|
||||
|
||||
def Jt(self, m, v, u=None):
|
||||
"""Takes data, turns it into a model..ish"""
|
||||
|
||||
if u is None:
|
||||
u = self.field(m)
|
||||
|
||||
u = self.reshapeFields(u)
|
||||
v = self.reshapeFields(v)
|
||||
|
||||
P = self.P
|
||||
D = self.mesh.faceDiv
|
||||
G = self.mesh.cellGrad
|
||||
A = self.createMatrix(m)
|
||||
Av_dm = self.mesh.getFaceMassDeriv()
|
||||
mT_dm = self.modelTransformDeriv(m)
|
||||
|
||||
dCdu = A.T
|
||||
solve = Solver(dCdu)
|
||||
|
||||
w = solve.solve(P.T*v)
|
||||
|
||||
Jtv = 0
|
||||
for i, ui in enumerate(u.T): # loop over each column
|
||||
Jtv += sdiag( G * ui ) * ( D.T * w[:,i] )
|
||||
|
||||
Jtv = - mT_dm.T * ( Av_dm.T * Jtv )
|
||||
return Jtv
|
||||
|
||||
|
||||
|
||||
def genTxRxmat(nelec, spacelec, surfloc, elecini, mesh):
|
||||
""" Generate projection matrix (Q) and """
|
||||
elecend = 0.5+spacelec*(nelec-1)
|
||||
elecLocR = np.linspace(elecini, elecend, nelec)
|
||||
elecLocT = elecLocR+1
|
||||
nrx = nelec-1
|
||||
ntx = nelec-1
|
||||
q = np.zeros((mesh.nC, ntx))
|
||||
Q = np.zeros((mesh.nC, nrx))
|
||||
|
||||
for i in range(nrx):
|
||||
|
||||
rxind1 = np.argwhere((mesh.gridCC[:,0]==surfloc) & (mesh.gridCC[:,1]==elecLocR[i]))
|
||||
rxind2 = np.argwhere((mesh.gridCC[:,0]==surfloc) & (mesh.gridCC[:,1]==elecLocR[i+1]))
|
||||
|
||||
txind1 = np.argwhere((mesh.gridCC[:,0]==surfloc) & (mesh.gridCC[:,1]==elecLocT[i]))
|
||||
txind2 = np.argwhere((mesh.gridCC[:,0]==surfloc) & (mesh.gridCC[:,1]==elecLocT[i+1]))
|
||||
|
||||
q[txind1,i] = 1
|
||||
q[txind2,i] = -1
|
||||
Q[rxind1,i] = 1
|
||||
Q[rxind2,i] = -1
|
||||
|
||||
Q = sp.csr_matrix(Q)
|
||||
rxmidLoc = (elecLocR[0:nelec-1]+elecLocR[1:nelec])*0.5
|
||||
return q, Q, rxmidLoc
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
from SimPEG.regularization import Regularization
|
||||
from SimPEG import inverse
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Create the mesh
|
||||
h1 = np.ones(20)
|
||||
h2 = np.ones(100)
|
||||
mesh = TensorMesh([h1,h2])
|
||||
|
||||
# Create some parameters for the model
|
||||
sig1 = np.log(1)
|
||||
sig2 = np.log(0.01)
|
||||
|
||||
# Create a synthetic model from a block in a half-space
|
||||
p0 = [5, 10]
|
||||
p1 = [15, 50]
|
||||
condVals = [sig1, sig2]
|
||||
mSynth = ModelBuilder.defineBlockConductivity(p0,p1,mesh.gridCC,condVals)
|
||||
plt.colorbar(mesh.plotImage(mSynth))
|
||||
plt.show()
|
||||
|
||||
# Set up the projection
|
||||
nelec = 50
|
||||
spacelec = 2
|
||||
surfloc = 0.5
|
||||
elecini = 0.5
|
||||
elecend = 0.5+spacelec*(nelec-1)
|
||||
elecLocR = np.linspace(elecini, elecend, nelec)
|
||||
rxmidLoc = (elecLocR[0:nelec-1]+elecLocR[1:nelec])*0.5
|
||||
q, Q, rxmidloc = genTxRxmat(nelec, spacelec, surfloc, elecini, mesh)
|
||||
P = Q.T
|
||||
|
||||
# Create some data
|
||||
class syntheticDCProblem(DCProblem, SyntheticProblem):
|
||||
pass
|
||||
|
||||
synthetic = syntheticDCProblem(mesh);
|
||||
synthetic.P = P
|
||||
synthetic.RHS = q
|
||||
dobs, Wd = synthetic.createData(mSynth, std=0.05)
|
||||
|
||||
u = synthetic.field(mSynth)
|
||||
u = synthetic.reshapeFields(u)
|
||||
mesh.plotImage(u[:,10])
|
||||
# plt.show()
|
||||
|
||||
# Now set up the problem to do some minimization
|
||||
problem = DCProblem(mesh)
|
||||
problem.P = P
|
||||
problem.RHS = q
|
||||
problem.dobs = dobs
|
||||
problem.std = dobs*0 + 0.05
|
||||
m0 = mesh.gridCC[:,0]*0+sig2
|
||||
|
||||
opt = inverse.InexactGaussNewton(maxIterLS=20, maxIter=10, tolF=1e-6, tolX=1e-6, tolG=1e-6, maxIterCG=6)
|
||||
reg = Regularization(mesh)
|
||||
inv = inverse.Inversion(problem, reg, opt, beta0=1e4)
|
||||
|
||||
# Check Derivative
|
||||
derChk = lambda m: [inv.dataObj(m), inv.dataObjDeriv(m)]
|
||||
checkDerivative(derChk, mSynth)
|
||||
|
||||
|
||||
|
||||
print inv.dataObj(m0)
|
||||
print inv.dataObj(mSynth)
|
||||
|
||||
m = inv.run(m0)
|
||||
|
||||
plt.colorbar(mesh.plotImage(m))
|
||||
print m
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
from SimPEG.mesh import TensorMesh
|
||||
from SimPEG.forward import Problem, SyntheticProblem
|
||||
from SimPEG.tests import checkDerivative
|
||||
from SimPEG.utils import ModelBuilder, sdiag
|
||||
import numpy as np
|
||||
import scipy.sparse.linalg as linalg
|
||||
import DCutils
|
||||
|
||||
class DCProblem(Problem):
|
||||
"""
|
||||
**DCProblem**
|
||||
|
||||
Geophysical DC resistivity problem.
|
||||
|
||||
"""
|
||||
def __init__(self, mesh):
|
||||
super(DCProblem, self).__init__(mesh)
|
||||
self.mesh.setCellGradBC('neumann')
|
||||
|
||||
def createMatrix(self, m):
|
||||
"""
|
||||
Makes the matrix A(m) for the DC resistivity problem.
|
||||
|
||||
:param numpy.array m: model
|
||||
:rtype: scipy.csc_matrix
|
||||
:return: A(m)
|
||||
|
||||
.. math::
|
||||
c(m,u) = A(m)u - q = G\\text{sdiag}(M(mT(m)))Du - q = 0
|
||||
|
||||
Where M() is the mass matrix and mT is the model transform.
|
||||
"""
|
||||
D = self.mesh.faceDiv
|
||||
G = self.mesh.cellGrad
|
||||
sigma = self.modelTransform(m)
|
||||
Msig = self.mesh.getFaceMass(sigma)
|
||||
A = D*Msig*G
|
||||
return A.tocsc()
|
||||
|
||||
def field(self, m):
|
||||
A = self.createMatrix(m)
|
||||
solve = linalg.factorized(A)
|
||||
|
||||
nRHSs = self.RHS.shape[1] # Number of RHSs
|
||||
phi = np.zeros((self.mesh.nC, nRHSs)) + np.nan
|
||||
for ii in range(nRHSs):
|
||||
phi[:,ii] = solve(self.RHS[:,ii])
|
||||
|
||||
return phi
|
||||
|
||||
def J(self, m, v, u=None, solve=None):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
:param numpy.array v: vector to multiply
|
||||
:param numpy.array u: fields
|
||||
:rtype: numpy.array
|
||||
:return: Jv
|
||||
|
||||
.. math::
|
||||
c(m,u) = A(m)u - q = G\\text{sdiag}(M(mT(m)))Du - q = 0
|
||||
|
||||
\\nabla_u (A(m)u - q) = A(m)
|
||||
|
||||
\\nabla_m (A(m)u - q) = G\\text{sdiag}(Du)\\nabla_m(M(mT(m)))
|
||||
|
||||
Where M() is the mass matrix and mT is the model transform.
|
||||
|
||||
.. math::
|
||||
J = - P \left( \\nabla_u c(m, u) \\right)^{-1} \\nabla_m c(m, u)
|
||||
|
||||
J(v) = - P ( A(m)^{-1} ( G\\text{sdiag}(Du)\\nabla_m(M(mT(m))) v ) )
|
||||
"""
|
||||
P = self.P
|
||||
D = self.mesh.faceDiv
|
||||
G = self.mesh.cellGrad
|
||||
A = self.createMatrix(m)
|
||||
Av_dm = self.mesh.getFaceMassDeriv()
|
||||
mT_dm = self.modelTransformDeriv(m)
|
||||
|
||||
dCdu = A
|
||||
dCdm = D * ( sdiag( G * u ) * ( Av_dm * ( mT_dm * v ) ) )
|
||||
|
||||
if solve is None:
|
||||
solve = linalg.factorized(dCdu)
|
||||
|
||||
Jv = - P * solve(dCdm)
|
||||
return Jv
|
||||
|
||||
def Jt(self, m, v, u=None, solve=None):
|
||||
P = self.P
|
||||
D = self.mesh.faceDiv
|
||||
G = self.mesh.cellGrad
|
||||
A = self.createMatrix(m)
|
||||
Av_dm = self.mesh.getFaceMassDeriv()
|
||||
mT_dm = self.modelTransformDeriv(m)
|
||||
|
||||
dCdu = A.T
|
||||
|
||||
if solve is None:
|
||||
solve = linalg.factorized(dCdu.tocsc())
|
||||
w = solve(P.T*v)
|
||||
|
||||
Jtv = - mT_dm.T * ( Av_dm.T * ( sdiag( G * u ) * ( D.T * w ) ) )
|
||||
return Jtv
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Create the mesh
|
||||
h1 = np.ones(100)
|
||||
h2 = np.ones(100)
|
||||
mesh = TensorMesh([h1,h2])
|
||||
|
||||
# Create some parameters for the model
|
||||
sig1 = 1
|
||||
sig2 = 0.01
|
||||
|
||||
# Create a synthetic model from a block in a half-space
|
||||
p0 = [20, 20]
|
||||
p1 = [50, 50]
|
||||
condVals = [sig1, sig2]
|
||||
mSynth = ModelBuilder.defineBlockConductivity(p0,p1,mesh.gridCC,condVals)
|
||||
mesh.plotImage(mSynth, showIt=False)
|
||||
|
||||
|
||||
# Set up the projection
|
||||
nelec = 50
|
||||
spacelec = 2
|
||||
surfloc = 0.5
|
||||
elecini = 0.5
|
||||
elecend = 0.5+spacelec*(nelec-1)
|
||||
elecLocR = np.linspace(elecini, elecend, nelec)
|
||||
rxmidLoc = (elecLocR[0:nelec-1]+elecLocR[1:nelec])*0.5
|
||||
q, Q, rxmidloc = DCutils.genTxRxmat(nelec, spacelec, surfloc, elecini, mesh)
|
||||
P = Q.T
|
||||
|
||||
# Create some data
|
||||
class syntheticDCProblem(DCProblem, SyntheticProblem):
|
||||
pass
|
||||
|
||||
synthetic = syntheticDCProblem(mesh);
|
||||
synthetic.P = P
|
||||
synthetic.RHS = q
|
||||
dobs, Wd = synthetic.createData(mSynth, std=0.05)
|
||||
|
||||
u = synthetic.field(mSynth)
|
||||
mesh.plotImage(u[:,10], showIt=True)
|
||||
|
||||
# Now set up the problem to do some minimization
|
||||
problem = DCProblem(mesh)
|
||||
problem.P = P
|
||||
problem.RHS = q
|
||||
problem.W = Wd
|
||||
problem.dobs = dobs
|
||||
m0 = mesh.gridCC[:,0]*0+sig1
|
||||
|
||||
print problem.misfit(m0)
|
||||
print problem.misfit(mSynth)
|
||||
|
||||
# Check Derivative
|
||||
derChk = lambda m: [problem.misfit(m), problem.misfitDeriv(m)]
|
||||
checkDerivative(derChk, mSynth)
|
||||
|
||||
# Adjoint Test
|
||||
u = np.random.rand(mesh.nC)
|
||||
v = np.random.rand(mesh.nC)
|
||||
w = np.random.rand(dobs.shape[0])
|
||||
print w.dot(problem.J(mSynth, v, u=u))
|
||||
print v.dot(problem.Jt(mSynth, w, u=u))
|
||||
@@ -1,29 +0,0 @@
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
|
||||
def genTxRxmat(nelec, spacelec, surfloc, elecini, mesh):
|
||||
""" Generate projection matrix (Q) and """
|
||||
elecend = 0.5+spacelec*(nelec-1)
|
||||
elecLocR = np.linspace(elecini, elecend, nelec)
|
||||
elecLocT = elecLocR+1
|
||||
nrx = nelec-1
|
||||
ntx = nelec-1
|
||||
q = np.zeros((mesh.nC, ntx))
|
||||
Q = np.zeros((mesh.nC, nrx))
|
||||
|
||||
for i in range(nrx):
|
||||
|
||||
rxind1 = np.argwhere((mesh.gridCC[:,0]==surfloc) & (mesh.gridCC[:,1]==elecLocR[i]))
|
||||
rxind2 = np.argwhere((mesh.gridCC[:,0]==surfloc) & (mesh.gridCC[:,1]==elecLocR[i+1]))
|
||||
|
||||
txind1 = np.argwhere((mesh.gridCC[:,0]==surfloc) & (mesh.gridCC[:,1]==elecLocT[i]))
|
||||
txind2 = np.argwhere((mesh.gridCC[:,0]==surfloc) & (mesh.gridCC[:,1]==elecLocT[i+1]))
|
||||
|
||||
q[txind1,i] = 1
|
||||
q[txind2,i] = -1
|
||||
Q[rxind1,i] = 1
|
||||
Q[rxind2,i] = -1
|
||||
|
||||
Q = sp.csr_matrix(Q)
|
||||
rxmidLoc = (elecLocR[0:nelec-1]+elecLocR[1:nelec])*0.5
|
||||
return q, Q, rxmidLoc
|
||||
@@ -1,2 +0,0 @@
|
||||
from DCProblem import *
|
||||
from DCutils import *
|
||||
@@ -0,0 +1,89 @@
|
||||
import numpy as np
|
||||
from SimPEG.mesh import TensorMesh
|
||||
from SimPEG.forward import Problem
|
||||
from SimPEG.regularization import Regularization
|
||||
from SimPEG.inverse import *
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class LinearProblem(Problem):
|
||||
"""docstring for LinearProblem"""
|
||||
|
||||
def dpred(self, m, u=None):
|
||||
return self.G.dot(m)
|
||||
|
||||
def J(self, m, v, u=None):
|
||||
return G.dot(v)
|
||||
|
||||
def Jt(self, m, v, u=None):
|
||||
return G.T.dot(v)
|
||||
|
||||
if __name__ == '__main__':
|
||||
N = 100
|
||||
h = np.ones(N)/N
|
||||
M = TensorMesh([h])
|
||||
|
||||
nk = 20
|
||||
jk = np.linspace(1.,20.,nk)
|
||||
p = -0.25
|
||||
q = 0.25
|
||||
|
||||
|
||||
|
||||
g = lambda k: np.exp(p*jk[k]*M.vectorCCx)*np.cos(2*np.pi*q*jk[k]*M.vectorCCx)
|
||||
|
||||
G = np.empty((nk, M.nC))
|
||||
|
||||
for i in range(nk):
|
||||
G[i,:] = g(i)
|
||||
|
||||
|
||||
|
||||
plt.figure(1)
|
||||
for i in range(nk):
|
||||
plt.plot(G[i,:])
|
||||
|
||||
|
||||
m_true = np.zeros(M.nC)
|
||||
m_true[M.vectorCCx > 0.3] = 1.
|
||||
m_true[M.vectorCCx > 0.45] = -0.5
|
||||
m_true[M.vectorCCx > 0.6] = 0
|
||||
|
||||
|
||||
d_true = G.dot(m_true)
|
||||
noise = 0.1 * np.random.rand(d_true.size)
|
||||
|
||||
d_obs = d_true + noise
|
||||
|
||||
# plt.figure(3)
|
||||
# plt.plot(d_true,'-o')
|
||||
# plt.plot(d_obs,'r-o')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
prob = LinearProblem(M)
|
||||
prob.G = G
|
||||
prob.dobs = d_obs
|
||||
prob.std = np.ones_like(d_obs)*0.1
|
||||
|
||||
reg = Regularization(M)
|
||||
|
||||
opt = InexactGaussNewton(maxIter=20)
|
||||
|
||||
inv = Inversion(prob,reg,opt,beta0=1e-4)
|
||||
|
||||
m0 = np.zeros_like(m_true)
|
||||
|
||||
mrec = inv.run(m0)
|
||||
|
||||
|
||||
plt.figure(2)
|
||||
|
||||
plt.plot(M.vectorCCx, m_true, 'b-')
|
||||
plt.plot(M.vectorCCx, mrec, 'r-')
|
||||
|
||||
|
||||
|
||||
plt.show()
|
||||
@@ -0,0 +1,49 @@
|
||||
import numpy as np
|
||||
from SimPEG.utils import mkvc, sdiag
|
||||
|
||||
class LogModel(object):
|
||||
"""docstring for LogModel"""
|
||||
def modelTransform(self, m):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
:rtype: numpy.array
|
||||
:return: transformed model
|
||||
|
||||
The modelTransform changes the model into the physical property.
|
||||
|
||||
A common example of this is to invert for electrical conductivity
|
||||
in log space. In this case, your model will be log(sigma) and to
|
||||
get back to sigma, you can take the exponential:
|
||||
|
||||
.. math::
|
||||
|
||||
m = \log{\sigma}
|
||||
|
||||
\exp{m} = \exp{\log{\sigma}} = \sigma
|
||||
"""
|
||||
return np.exp(mkvc(m))
|
||||
|
||||
def modelTransformDeriv(self, m):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
:rtype: scipy.csr_matrix
|
||||
:return: derivative of transformed model
|
||||
|
||||
The modelTransform changes the model into the physical property.
|
||||
The modelTransformDeriv provides the derivative of the modelTransform.
|
||||
|
||||
If the model transform is:
|
||||
|
||||
.. math::
|
||||
|
||||
m = \log{\sigma}
|
||||
|
||||
\exp{m} = \exp{\log{\sigma}} = \sigma
|
||||
|
||||
Then the derivative is:
|
||||
|
||||
.. math::
|
||||
|
||||
\\frac{\partial \exp{m}}{\partial m} = \\text{sdiag}(\exp{m})
|
||||
"""
|
||||
return sdiag(np.exp(mkvc(m)))
|
||||
+71
-156
@@ -1,5 +1,6 @@
|
||||
import numpy as np
|
||||
from SimPEG.utils import mkvc, sdiag
|
||||
import scipy.sparse as sp
|
||||
norm = np.linalg.norm
|
||||
|
||||
|
||||
@@ -49,16 +50,6 @@ class Problem(object):
|
||||
def RHS(self, value):
|
||||
self._RHS = value
|
||||
|
||||
@property
|
||||
def W(self):
|
||||
"""
|
||||
Standard deviation weighting matrix.
|
||||
"""
|
||||
return self._W
|
||||
@W.setter
|
||||
def W(self, value):
|
||||
self._W = value
|
||||
|
||||
@property
|
||||
def P(self):
|
||||
"""
|
||||
@@ -72,6 +63,15 @@ class Problem(object):
|
||||
def P(self, value):
|
||||
self._P = value
|
||||
|
||||
@property
|
||||
def std(self):
|
||||
"""
|
||||
Estimated Standard Deviations.
|
||||
"""
|
||||
return self._std
|
||||
@std.setter
|
||||
def std(self, value):
|
||||
self._std = value
|
||||
|
||||
@property
|
||||
def dobs(self):
|
||||
@@ -83,16 +83,35 @@ class Problem(object):
|
||||
def dobs(self, value):
|
||||
self._dobs = value
|
||||
|
||||
def evalFunction(self, m, doDerivative=True):
|
||||
def dpred(self, m, u=None):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
:param bool doDerivative: do you want to compute the derivative?
|
||||
:rtype: numpy.array
|
||||
:return: Jv
|
||||
"""
|
||||
f = self.misfit(m)
|
||||
Predicted data.
|
||||
|
||||
return f, g, H
|
||||
.. math::
|
||||
d_\\text{pred} = Pu(m)
|
||||
"""
|
||||
if u is None:
|
||||
u = self.field(m)
|
||||
return self.P*u
|
||||
|
||||
def dataResidual(self, m, u=None):
|
||||
"""
|
||||
:param numpy.array m: geophysical model
|
||||
:param numpy.array u: fields
|
||||
:rtype: float
|
||||
:return: data misfit
|
||||
|
||||
The data misfit:
|
||||
|
||||
.. math::
|
||||
|
||||
\mu_\\text{data} = \mathbf{d}_\\text{pred} - \mathbf{d}_\\text{obs}
|
||||
|
||||
Where P is a projection matrix that brings the field on the full domain to the data measurement locations;
|
||||
u is the field of interest; d_obs is the observed data.
|
||||
"""
|
||||
|
||||
return self.dpred(m, u=u) - self.dobs
|
||||
|
||||
def J(self, m, v, u=None):
|
||||
"""
|
||||
@@ -131,10 +150,38 @@ class Problem(object):
|
||||
:rtype: numpy.array
|
||||
:return: JTv
|
||||
|
||||
Transpose of J
|
||||
Effect of transpose of J on a vector v.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def J_approx(self, m, v, u=None):
|
||||
"""
|
||||
|
||||
:param numpy.array m: model
|
||||
:param numpy.array v: vector to multiply
|
||||
:param numpy.array u: fields
|
||||
:rtype: numpy.array
|
||||
:return: Jv
|
||||
|
||||
Approximate effect of J on a vector v
|
||||
|
||||
"""
|
||||
return self.J(m, v, u)
|
||||
|
||||
def Jt_approx(self, m, v, u=None):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
:param numpy.array v: vector to multiply
|
||||
:param numpy.array u: fields
|
||||
:rtype: numpy.array
|
||||
:return: JTv
|
||||
|
||||
Approximate transpose of J*v
|
||||
|
||||
"""
|
||||
return self.Jt(m, v, u)
|
||||
|
||||
def field(self, m):
|
||||
"""
|
||||
The field given the model.
|
||||
@@ -145,17 +192,6 @@ class Problem(object):
|
||||
"""
|
||||
pass
|
||||
|
||||
def dpred(self, m, u=None):
|
||||
"""
|
||||
Predicted data.
|
||||
|
||||
.. math::
|
||||
d_\\text{pred} = Pu(m)
|
||||
"""
|
||||
if u is None:
|
||||
u = self.field(m)
|
||||
return self.P*u
|
||||
|
||||
def modelTransform(self, m):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
@@ -168,13 +204,8 @@ class Problem(object):
|
||||
in log space. In this case, your model will be log(sigma) and to
|
||||
get back to sigma, you can take the exponential:
|
||||
|
||||
.. math::
|
||||
|
||||
m = \log{\sigma}
|
||||
|
||||
\exp{m} = \exp{\log{\sigma}} = \sigma
|
||||
"""
|
||||
return np.exp(mkvc(m))
|
||||
return m
|
||||
|
||||
def modelTransformDeriv(self, m):
|
||||
"""
|
||||
@@ -184,129 +215,10 @@ class Problem(object):
|
||||
|
||||
The modelTransform changes the model into the physical property.
|
||||
The modelTransformDeriv provides the derivative of the modelTransform.
|
||||
|
||||
If the model transform is:
|
||||
|
||||
.. math::
|
||||
|
||||
m = \log{\sigma}
|
||||
|
||||
\exp{m} = \exp{\log{\sigma}} = \sigma
|
||||
|
||||
Then the derivative is:
|
||||
|
||||
.. math::
|
||||
|
||||
\\frac{\partial \exp{m}}{\partial m} = \\text{sdiag}(\exp{m})
|
||||
"""
|
||||
return sdiag(np.exp(mkvc(m)))
|
||||
return sp.eye(m.size)
|
||||
|
||||
def misfit(self, m, u=None):
|
||||
"""
|
||||
:param numpy.array m: geophysical model
|
||||
:param numpy.array u: fields
|
||||
:rtype: float
|
||||
:return: data misfit
|
||||
|
||||
The data misfit using an l_2 norm is:
|
||||
|
||||
.. math::
|
||||
|
||||
\mu_\\text{data} = {1\over 2}\left| \mathbf{W} \circ (\mathbf{d}_\\text{pred} - \mathbf{d}_\\text{obs}) \\right|_2^2
|
||||
|
||||
Where P is a projection matrix that brings the field on the full domain to the data measurement locations;
|
||||
u is the field of interest; d_obs is the observed data; and W is the weighting matrix.
|
||||
"""
|
||||
|
||||
R = self.W*(self.dpred(m, u=u) - self.dobs)
|
||||
R = mkvc(R)
|
||||
return 0.5*R.dot(R)
|
||||
|
||||
def misfitDeriv(self, m, u=None):
|
||||
"""
|
||||
:param numpy.array m: geophysical model
|
||||
:param numpy.array u: fields
|
||||
:rtype: numpy.array
|
||||
:return: data misfit derivative
|
||||
|
||||
The data misfit using an l_2 norm is:
|
||||
|
||||
.. math::
|
||||
|
||||
\mu_\\text{data} = {1\over 2}\left| \mathbf{W} \circ (\mathbf{d}_\\text{pred} - \mathbf{d}_\\text{obs}) \\right|_2^2
|
||||
|
||||
If the field, u, is provided, the calculation of the data is fast:
|
||||
|
||||
.. math::
|
||||
|
||||
\mathbf{d}_\\text{pred} = \mathbf{Pu(m)}
|
||||
|
||||
\mathbf{R} = \mathbf{W} \circ (\mathbf{d}_\\text{pred} - \mathbf{d}_\\text{obs})
|
||||
|
||||
Where P is a projection matrix that brings the field on the full domain to the data measurement locations;
|
||||
u is the field of interest; d_obs is the observed data; and W is the weighting matrix.
|
||||
|
||||
The derivative of this, with respect to the model, is:
|
||||
|
||||
.. math::
|
||||
|
||||
\\frac{\partial \mu_\\text{data}}{\partial \mathbf{m}} = \mathbf{J}^\\top \mathbf{W \circ R}
|
||||
|
||||
"""
|
||||
if u is None:
|
||||
u = self.field(m)
|
||||
|
||||
R = self.W*(self.dpred(m, u=u) - self.dobs)
|
||||
|
||||
dmisfit = 0
|
||||
for i in range(self.RHS.shape[1]): # Loop over each right hand side
|
||||
dmisfit += self.Jt(m, self.W[:,i]*R[:,i], u=u[:,i])
|
||||
|
||||
return dmisfit
|
||||
|
||||
def misfitDerivDeriv(self, m, u=None):
|
||||
"""
|
||||
:param numpy.array m: geophysical model
|
||||
:param numpy.array u: fields
|
||||
:rtype: numpy.array
|
||||
:return: data misfit derivative
|
||||
|
||||
The data misfit using an l_2 norm is:
|
||||
|
||||
.. math::
|
||||
|
||||
\mu_\\text{data} = {1\over 2}\left| \mathbf{W} \circ (\mathbf{d}_\\text{pred} - \mathbf{d}_\\text{obs}) \\right|_2^2
|
||||
|
||||
If the field, u, is provided, the calculation of the data is fast:
|
||||
|
||||
.. math::
|
||||
|
||||
\mathbf{d}_\\text{pred} = \mathbf{Pu(m)}
|
||||
|
||||
\mathbf{R} = \mathbf{W} \circ (\mathbf{d}_\\text{pred} - \mathbf{d}_\\text{obs})
|
||||
|
||||
Where P is a projection matrix that brings the field on the full domain to the data measurement locations;
|
||||
u is the field of interest; d_obs is the observed data; and W is the weighting matrix.
|
||||
|
||||
The derivative of this, with respect to the model, is:
|
||||
|
||||
.. math::
|
||||
|
||||
\\frac{\partial \mu_\\text{data}}{\partial \mathbf{m}} = \mathbf{J}^\\top \mathbf{W \circ R}
|
||||
|
||||
\\frac{\partial^2 \mu_\\text{data}}{\partial^2 \mathbf{m}} = \mathbf{J}^\\top \mathbf{W \circ W J}
|
||||
|
||||
"""
|
||||
if u is None:
|
||||
u = self.field(m)
|
||||
|
||||
R = self.W*(self.dpred(m, u=u) - self.dobs)
|
||||
|
||||
dmisfit = 0
|
||||
for i in range(self.RHS.shape[1]): # Loop over each right hand side
|
||||
dmisfit += self.Jt(m, self.W[:,i]*R[:,i], u=u[:,i])
|
||||
|
||||
return dmisfit
|
||||
|
||||
|
||||
class SyntheticProblem(object):
|
||||
@@ -337,3 +249,6 @@ class SyntheticProblem(object):
|
||||
eps = np.linalg.norm(mkvc(dobs),2)*1e-5
|
||||
Wd = 1/(abs(dobs)*std+eps)
|
||||
return dobs, Wd
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
from Problem import *
|
||||
import DCProblem
|
||||
from LinearProblem import LinearProblem
|
||||
import ModelTransforms
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
|
||||
class Cooling(object):
|
||||
"""Simple Beta Schedule"""
|
||||
|
||||
beta0 = 1.e6
|
||||
beta_coolingFactor = 5.
|
||||
|
||||
def getBeta(self):
|
||||
if self._beta is None:
|
||||
return beta0
|
||||
return self._beta / beta_coolingFactor
|
||||
@@ -0,0 +1,221 @@
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
from SimPEG.utils import sdiag, mkvc
|
||||
|
||||
class Inversion(object):
|
||||
"""docstring for Inversion"""
|
||||
|
||||
maxIter = 10
|
||||
name = 'SimPEG Inversion'
|
||||
|
||||
def __init__(self, prob, reg, opt, **kwargs):
|
||||
self.prob = prob
|
||||
self.reg = reg
|
||||
self.opt = opt
|
||||
self.opt.parent = self
|
||||
self.setKwargs(**kwargs)
|
||||
|
||||
def setKwargs(self, **kwargs):
|
||||
"""Sets key word arguments (kwargs) that are present in the object, throw an error if they don't exist."""
|
||||
for attr in kwargs:
|
||||
if hasattr(self, attr):
|
||||
setattr(self, attr, kwargs[attr])
|
||||
else:
|
||||
raise Exception('%s attr is not recognized' % attr)
|
||||
|
||||
def printInit(self):
|
||||
print "%s %s %s" % ('='*22, self.name, '='*22)
|
||||
print " # beta phi_d phi_m f norm(dJ) #LS"
|
||||
print "%s" % '-'*62
|
||||
|
||||
def printIter(self):
|
||||
print "%3d %1.2e %1.2e %1.2e %1.2e %1.2e %3d" % (self.opt._iter, self._beta, self._phi_d_last, self._phi_m_last, self.opt.f, np.linalg.norm(self.opt.g), self.opt._iterLS)
|
||||
|
||||
@property
|
||||
def Wd(self):
|
||||
"""
|
||||
Standard deviation weighting matrix.
|
||||
"""
|
||||
if getattr(self,'_Wd',None) is None:
|
||||
eps = np.linalg.norm(mkvc(self.prob.dobs),2)*1e-5
|
||||
self._Wd = 1/(abs(self.prob.dobs)*self.prob.std+eps)
|
||||
return self._Wd
|
||||
|
||||
@property
|
||||
def phi_d_target(self):
|
||||
"""
|
||||
target for phi_d
|
||||
|
||||
By default this is the number of data.
|
||||
|
||||
Note that we do not set the target if it is None, but we return the default value.
|
||||
"""
|
||||
if getattr(self, '_phi_d_target', None) is None:
|
||||
return self.prob.dobs.size #
|
||||
return self._phi_d_target
|
||||
@phi_d_target.setter
|
||||
def phi_d_target(self, value):
|
||||
self._phi_d_target = value
|
||||
|
||||
def run(self, m0):
|
||||
m = m0
|
||||
self._iter = 0
|
||||
self._beta = None
|
||||
while True:
|
||||
self._beta = self.getBeta()
|
||||
m = self.opt.minimize(self.evalFunction,m)
|
||||
if self.stoppingCriteria(): break
|
||||
self._iter += 1
|
||||
return m
|
||||
|
||||
beta0 = 1.e2
|
||||
beta_coolingFactor = 5.
|
||||
|
||||
def getBeta(self):
|
||||
if self._beta is None:
|
||||
return self.beta0
|
||||
return self._beta / self.beta_coolingFactor
|
||||
|
||||
def stoppingCriteria(self):
|
||||
self._STOP = np.zeros(2,dtype=bool)
|
||||
self._STOP[0] = self._iter >= self.maxIter
|
||||
self._STOP[1] = self._phi_d_last <= self.phi_d_target
|
||||
return np.any(self._STOP)
|
||||
|
||||
|
||||
def evalFunction(self, m, return_g=True, return_H=True):
|
||||
|
||||
u = self.prob.field(m)
|
||||
phi_d = self.dataObj(m, u)
|
||||
phi_m = self.reg.modelObj(m)
|
||||
|
||||
self._phi_d_last = phi_d
|
||||
self._phi_m_last = phi_m
|
||||
|
||||
f = phi_d + self._beta * phi_m
|
||||
|
||||
out = (f,)
|
||||
if return_g:
|
||||
phi_dDeriv = self.dataObjDeriv(m, u=u)
|
||||
phi_mDeriv = self.reg.modelObjDeriv(m)
|
||||
|
||||
g = phi_dDeriv + self._beta * phi_mDeriv
|
||||
out += (g,)
|
||||
|
||||
if return_H:
|
||||
def H_fun(v):
|
||||
phi_d2Deriv = self.dataObj2Deriv(m, v, u=u)
|
||||
phi_m2Deriv = self.reg.modelObj2Deriv(m)*v
|
||||
|
||||
return phi_d2Deriv + self._beta * phi_m2Deriv
|
||||
|
||||
operator = sp.linalg.LinearOperator( (m.size, m.size), H_fun, dtype=float )
|
||||
out += (operator,)
|
||||
return out
|
||||
|
||||
|
||||
def dataObj(self, m, u=None):
|
||||
"""
|
||||
:param numpy.array m: geophysical model
|
||||
:param numpy.array u: fields
|
||||
:rtype: float
|
||||
:return: data misfit
|
||||
|
||||
The data misfit using an l_2 norm is:
|
||||
|
||||
.. math::
|
||||
|
||||
\mu_\\text{data} = {1\over 2}\left| \mathbf{W} \circ (\mathbf{d}_\\text{pred} - \mathbf{d}_\\text{obs}) \\right|_2^2
|
||||
|
||||
Where P is a projection matrix that brings the field on the full domain to the data measurement locations;
|
||||
u is the field of interest; d_obs is the observed data; and W is the weighting matrix.
|
||||
"""
|
||||
# TODO: ensure that this is a data is vector and Wd is a matrix.
|
||||
R = self.Wd*self.prob.dataResidual(m, u=u)
|
||||
R = mkvc(R)
|
||||
return 0.5*np.vdot(R, R)
|
||||
|
||||
def dataObjDeriv(self, m, u=None):
|
||||
"""
|
||||
:param numpy.array m: geophysical model
|
||||
:param numpy.array u: fields
|
||||
:rtype: numpy.array
|
||||
:return: data misfit derivative
|
||||
|
||||
The data misfit using an l_2 norm is:
|
||||
|
||||
.. math::
|
||||
|
||||
\mu_\\text{data} = {1\over 2}\left| \mathbf{W} \circ (\mathbf{d}_\\text{pred} - \mathbf{d}_\\text{obs}) \\right|_2^2
|
||||
|
||||
If the field, u, is provided, the calculation of the data is fast:
|
||||
|
||||
.. math::
|
||||
|
||||
\mathbf{d}_\\text{pred} = \mathbf{Pu(m)}
|
||||
|
||||
\mathbf{R} = \mathbf{W} \circ (\mathbf{d}_\\text{pred} - \mathbf{d}_\\text{obs})
|
||||
|
||||
Where P is a projection matrix that brings the field on the full domain to the data measurement locations;
|
||||
u is the field of interest; d_obs is the observed data; and W is the weighting matrix.
|
||||
|
||||
The derivative of this, with respect to the model, is:
|
||||
|
||||
.. math::
|
||||
|
||||
\\frac{\partial \mu_\\text{data}}{\partial \mathbf{m}} = \mathbf{J}^\\top \mathbf{W \circ R}
|
||||
|
||||
"""
|
||||
if u is None:
|
||||
u = self.prob.field(m)
|
||||
|
||||
R = self.Wd*self.prob.dataResidual(m, u=u)
|
||||
|
||||
dmisfit = self.prob.Jt(m, self.Wd * R, u=u)
|
||||
|
||||
return dmisfit
|
||||
|
||||
def dataObj2Deriv(self, m, v, u=None):
|
||||
"""
|
||||
:param numpy.array m: geophysical model
|
||||
:param numpy.array u: fields
|
||||
:rtype: numpy.array
|
||||
:return: data misfit derivative
|
||||
|
||||
The data misfit using an l_2 norm is:
|
||||
|
||||
.. math::
|
||||
|
||||
\mu_\\text{data} = {1\over 2}\left| \mathbf{W} \circ (\mathbf{d}_\\text{pred} - \mathbf{d}_\\text{obs}) \\right|_2^2
|
||||
|
||||
If the field, u, is provided, the calculation of the data is fast:
|
||||
|
||||
.. math::
|
||||
|
||||
\mathbf{d}_\\text{pred} = \mathbf{Pu(m)}
|
||||
|
||||
\mathbf{R} = \mathbf{W} \circ (\mathbf{d}_\\text{pred} - \mathbf{d}_\\text{obs})
|
||||
|
||||
Where P is a projection matrix that brings the field on the full domain to the data measurement locations;
|
||||
u is the field of interest; d_obs is the observed data; and W is the weighting matrix.
|
||||
|
||||
The derivative of this, with respect to the model, is:
|
||||
|
||||
.. math::
|
||||
|
||||
\\frac{\partial \mu_\\text{data}}{\partial \mathbf{m}} = \mathbf{J}^\\top \mathbf{W \circ R}
|
||||
|
||||
\\frac{\partial^2 \mu_\\text{data}}{\partial^2 \mathbf{m}} = \mathbf{J}^\\top \mathbf{W \circ W J}
|
||||
|
||||
"""
|
||||
if u is None:
|
||||
u = self.prob.field(m)
|
||||
|
||||
R = self.Wd*self.prob.dataResidual(m, u=u)
|
||||
|
||||
# TODO: abstract to different norms a little cleaner.
|
||||
# \/ it goes here. in l2 it is the identity.
|
||||
dmisfit = self.prob.Jt_approx(m, self.Wd * self.Wd * self.prob.J_approx(m, v, u=u), u=u)
|
||||
|
||||
return dmisfit
|
||||
|
||||
+275
-38
@@ -2,54 +2,154 @@ import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from SimPEG.utils import mkvc, sdiag
|
||||
norm = np.linalg.norm
|
||||
import scipy.sparse as sp
|
||||
from SimPEG import Solver
|
||||
|
||||
try:
|
||||
from pubsub import pub
|
||||
doPub = True
|
||||
except Exception, e:
|
||||
print 'Warning: you may not have the required pubsub installed, use pypubsub. You will not be able to listen to events.'
|
||||
doPub = False
|
||||
|
||||
|
||||
|
||||
class Minimize(object):
|
||||
"""docstring for Minimize"""
|
||||
"""
|
||||
|
||||
Minimize is a general class for derivative based optimization.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
name = "GeneralOptimizationAlgorithm"
|
||||
|
||||
maxIter = 20
|
||||
maxIterLS = 10
|
||||
maxStep = np.inf
|
||||
LSreduction = 1e-4
|
||||
LSshorten = 0.5
|
||||
tolF = 1e-4
|
||||
tolX = 1e-4
|
||||
tolG = 1e-4
|
||||
eps = 1e-16
|
||||
tolF = 1e-1
|
||||
tolX = 1e-1
|
||||
tolG = 1e-1
|
||||
eps = 1e-5
|
||||
|
||||
def __init__(self, problem, **kwargs):
|
||||
self.problem = problem
|
||||
def __init__(self, **kwargs):
|
||||
self._id = int(np.random.rand()*1e6) # create a unique identifier to this program to be used in pubsub
|
||||
self.setKwargs(**kwargs)
|
||||
|
||||
def setKwargs(self, **kwargs):
|
||||
# Set the variables, throw an error if they don't exist.
|
||||
"""Sets key word arguments (kwargs) that are present in the object, throw an error if they don't exist."""
|
||||
for attr in kwargs:
|
||||
if hasattr(self, attr):
|
||||
setattr(self, attr, kwargs[attr])
|
||||
else:
|
||||
raise Exception('%s attr is not recognized' % attr)
|
||||
|
||||
def minimize(self, x0):
|
||||
def minimize(self, evalFunction, x0):
|
||||
"""
|
||||
Minimizes the function (evalFunction) starting at the location x0.
|
||||
|
||||
:param def evalFunction: function handle that evaluates: f, g, H = F(x)
|
||||
:param numpy.ndarray x0: starting location
|
||||
:rtype: numpy.ndarray
|
||||
:return: x, the last iterate of the optimization algorithm
|
||||
|
||||
evalFunction is a function handle::
|
||||
|
||||
(f[, g][, H]) = evalFunction(x, return_g=False, return_H=False )
|
||||
|
||||
|
||||
Events are fired with the following inputs via pypubsub::
|
||||
|
||||
Minimize.printInit (minimize)
|
||||
Minimize.evalFunction (minimize, f, g, H)
|
||||
Minimize.printIter (minimize)
|
||||
Minimize.searchDirection (minimize, p)
|
||||
Minimize.scaleSearchDirection (minimize, p)
|
||||
Minimize.modifySearchDirection (minimize, xt, passLS)
|
||||
Minimize.endIteration (minimize, xt)
|
||||
Minimize.printDone (minimize)
|
||||
|
||||
To hook into one of these events (must have pypubsub installed)::
|
||||
|
||||
from pubsub import pub
|
||||
def listener(minimize,p):
|
||||
print 'The search direction is: ', p
|
||||
pub.subscribe(listener, 'Minimize.searchDirection')
|
||||
|
||||
You can use pubsub communication to debug your code, it is not used internally.
|
||||
|
||||
|
||||
The algorithm for general minimization is as follows::
|
||||
|
||||
startup(x0)
|
||||
printInit()
|
||||
|
||||
while True:
|
||||
f, g, H = evalFunction(xc)
|
||||
printIter()
|
||||
if stoppingCriteria(): break
|
||||
p = findSearchDirection()
|
||||
p = scaleSearchDirection(p)
|
||||
xt, passLS = modifySearchDirection(p)
|
||||
if not passLS:
|
||||
xt, caught = modifySearchDirectionBreak(p)
|
||||
if not caught: return xc
|
||||
doEndIteration(xt)
|
||||
|
||||
printDone()
|
||||
return xc
|
||||
"""
|
||||
self.evalFunction = evalFunction
|
||||
self.startup(x0)
|
||||
self.printInit()
|
||||
|
||||
while True:
|
||||
self.f, self.g, self.H = self.evalFunction(self.xc)
|
||||
self.f, self.g, self.H = evalFunction(self.xc, return_g=True, return_H=True)
|
||||
if doPub: pub.sendMessage('Minimize.evalFunction', minimize=self, f=self.f, g=self.g, H=self.H)
|
||||
self.printIter()
|
||||
if self.stoppingCriteria(): break
|
||||
p = self.findSearchDirection()
|
||||
xt, passLS = self.linesearch(p)
|
||||
if doPub: pub.sendMessage('Minimize.searchDirection', minimize=self, p=p)
|
||||
p = self.scaleSearchDirection(p)
|
||||
if doPub: pub.sendMessage('Minimize.scaleSearchDirection', minimize=self, p=p)
|
||||
xt, passLS = self.modifySearchDirection(p)
|
||||
if doPub: pub.sendMessage('Minimize.modifySearchDirection', minimize=self, xt=xt, passLS=passLS)
|
||||
if not passLS:
|
||||
xt = self.linesearchBreak(p)
|
||||
xt, caught = self.modifySearchDirectionBreak(p)
|
||||
if not caught: return self.xc
|
||||
self.doEndIteration(xt)
|
||||
if doPub: pub.sendMessage('Minimize.endIteration', minimize=self, xt=xt)
|
||||
|
||||
self.printDone()
|
||||
|
||||
return self.xc
|
||||
|
||||
@property
|
||||
def parent(self):
|
||||
"""
|
||||
This is the parent of the optimization routine.
|
||||
"""
|
||||
return getattr(self, '_parent', None)
|
||||
@parent.setter
|
||||
def parent(self, value):
|
||||
self._parent = value
|
||||
|
||||
def startup(self, x0):
|
||||
"""
|
||||
**startup** is called at the start of any new minimize call.
|
||||
|
||||
This will set::
|
||||
|
||||
x0 = x0
|
||||
xc = x0
|
||||
_iter = _iterLS = 0
|
||||
|
||||
:param numpy.ndarray x0: initial x
|
||||
:rtype: None
|
||||
:return: None
|
||||
"""
|
||||
self._iter = 0
|
||||
self._iterLS = 0
|
||||
self._STOP = np.zeros((5,1),dtype=bool)
|
||||
@@ -59,29 +159,57 @@ class Minimize(object):
|
||||
self.xOld = x0
|
||||
|
||||
def printInit(self):
|
||||
"""
|
||||
**printInit** is called at the beginning of the optimization routine.
|
||||
|
||||
If there is a parent object, printInit will check for a
|
||||
parent.printInit function and call that.
|
||||
|
||||
"""
|
||||
if doPub: pub.sendMessage('Minimize.printInit', minimize=self)
|
||||
if self.parent is not None and hasattr(self.parent, 'printInit'):
|
||||
self.parent.printInit()
|
||||
return
|
||||
print "%s %s %s" % ('='*22, self.name, '='*22)
|
||||
print "iter\tJc\t\tnorm(dJ)\tLS"
|
||||
print "%s" % '-'*57
|
||||
|
||||
def printIter(self):
|
||||
"""
|
||||
**printIter** is called directly after function evaluations.
|
||||
|
||||
If there is a parent object, printIter will check for a
|
||||
parent.printIter function and call that.
|
||||
|
||||
"""
|
||||
if doPub: pub.sendMessage('Minimize.printIter', minimize=self)
|
||||
if self.parent is not None and hasattr(self.parent, 'printIter'):
|
||||
self.parent.printIter()
|
||||
return
|
||||
print "%3d\t%1.2e\t%1.2e\t%d" % (self._iter, self.f, norm(self.g), self._iterLS)
|
||||
|
||||
def printDone(self):
|
||||
"""
|
||||
**printDone** is called at the end of the optimization routine.
|
||||
|
||||
If there is a parent object, printDone will check for a
|
||||
parent.printDone function and call that.
|
||||
|
||||
"""
|
||||
if doPub: pub.sendMessage('Minimize.printDone', minimize=self)
|
||||
if self.parent is not None and hasattr(self.parent, 'printDone'):
|
||||
self.parent.printDone()
|
||||
return
|
||||
print "%s STOP! %s" % ('-'*25,'-'*25)
|
||||
print "%d : |fc-fOld| = %1.4e <= tolF*(1+|fStop|) = %1.4e" % (self._STOP[0], abs(self.f-self.fOld), self.tolF*(1+abs(self.fStop)))
|
||||
print "%d : |xc-xOld| = %1.4e <= tolX*(1+|x0|) = %1.4e" % (self._STOP[1], norm(self.xc-self.xOld), self.tolX*(1+norm(self.x0)))
|
||||
# TODO: put controls on gradient value, min model update, and function value
|
||||
if self._iter > 0:
|
||||
print "%d : |fc-fOld| = %1.4e <= tolF*(1+|fStop|) = %1.4e" % (self._STOP[0], abs(self.f-self.fOld), self.tolF*(1+abs(self.fStop)))
|
||||
print "%d : |xc-xOld| = %1.4e <= tolX*(1+|x0|) = %1.4e" % (self._STOP[1], norm(self.xc-self.xOld), self.tolX*(1+norm(self.x0)))
|
||||
print "%d : |g| = %1.4e <= tolG*(1+|fStop|) = %1.4e" % (self._STOP[2], norm(self.g), self.tolG*(1+abs(self.fStop)))
|
||||
print "%d : |g| = %1.4e <= 1e3*eps = %1.4e" % (self._STOP[3], norm(self.g), 1e3*self.eps)
|
||||
print "%d : iter = %3d\t <= maxIter\t = %3d" % (self._STOP[4], self._iter, self.maxIter)
|
||||
print "%s DONE! %s\n" % ('='*25,'='*25)
|
||||
|
||||
def evalFunction(self, x, doDerivative=True):
|
||||
f, g, H = self.problem(x)
|
||||
return f, g, H
|
||||
|
||||
def findSearchDirection(self):
|
||||
return -self.g
|
||||
|
||||
def stoppingCriteria(self):
|
||||
if self._iter == 0:
|
||||
self.fStop = self.f # Save this for stopping criteria
|
||||
@@ -94,14 +222,87 @@ class Minimize(object):
|
||||
self._STOP[4] = self._iter >= self.maxIter
|
||||
return all(self._STOP[0:3]) | any(self._STOP[3:])
|
||||
|
||||
def linesearch(self, p):
|
||||
def projection(self, p):
|
||||
"""
|
||||
projects the search direction.
|
||||
|
||||
by default, no projection is applied.
|
||||
|
||||
:param numpy.ndarray p: searchDirection
|
||||
:rtype: numpy.ndarray
|
||||
:return: p, projected search direction
|
||||
"""
|
||||
return p
|
||||
|
||||
def findSearchDirection(self):
|
||||
"""
|
||||
**findSearchDirection** should return an approximation of:
|
||||
|
||||
.. math::
|
||||
|
||||
H p = - g
|
||||
|
||||
Where you are solving for the search direction, p
|
||||
|
||||
The default is:
|
||||
|
||||
.. math::
|
||||
|
||||
H = I
|
||||
|
||||
p = - g
|
||||
|
||||
And corresponds to SteepestDescent.
|
||||
|
||||
The latest function evaluations are present in::
|
||||
|
||||
self.f, self.g, self.H
|
||||
|
||||
:rtype: numpy.ndarray
|
||||
:return: p, Search Direction
|
||||
"""
|
||||
return -self.g
|
||||
|
||||
def scaleSearchDirection(self, p):
|
||||
"""
|
||||
**scaleSearchDirection** should scale the search direction if appropriate.
|
||||
|
||||
Set the parameter **maxStep** in the minimize object, to scale back the gradient to a maximum size.
|
||||
|
||||
:param numpy.ndarray p: searchDirection
|
||||
:rtype: numpy.ndarray
|
||||
:return: p, Scaled Search Direction
|
||||
"""
|
||||
|
||||
if self.maxStep < np.abs(p.max()):
|
||||
p = self.maxStep*p/np.abs(p.max())
|
||||
return p
|
||||
|
||||
def modifySearchDirection(self, p):
|
||||
"""
|
||||
**modifySearchDirection** changes the search direction based on some sort of linesearch or trust-region criteria.
|
||||
|
||||
By default, an Armijo backtracking linesearch is preformed with the following parameters:
|
||||
|
||||
* maxIterLS, the maximum number of linesearch iterations
|
||||
* LSreduction, the expected reduction expected, default: 1e-4
|
||||
* LSshorten, how much the step is reduced, default: 0.5
|
||||
|
||||
If the linesearch is completed, and a descent direction is found, passLS is returned as True.
|
||||
|
||||
Else, a modifySearchDirectionBreak call is preformed.
|
||||
|
||||
:param numpy.ndarray p: searchDirection
|
||||
:rtype: numpy.ndarray,bool
|
||||
:return: (xt, passLS)
|
||||
"""
|
||||
# Armijo linesearch
|
||||
descent = np.inner(self.g, p)
|
||||
t = 1
|
||||
iterLS = 0
|
||||
while iterLS < self.maxIterLS:
|
||||
xt = self.xc + t*p
|
||||
ft, temp, temp = self.evalFunction(xt, doDerivative=False)
|
||||
xt = self.projection(self.xc + t*p)
|
||||
ft = self.evalFunction(xt, return_g=False, return_H=False)
|
||||
if ft < self.f + t*self.LSreduction*descent:
|
||||
break
|
||||
iterLS += 1
|
||||
@@ -110,10 +311,37 @@ class Minimize(object):
|
||||
self._iterLS = iterLS
|
||||
return xt, iterLS < self.maxIterLS
|
||||
|
||||
def linesearchBreak(self, p):
|
||||
raise Exception('The linesearch got broken. Boo.')
|
||||
def modifySearchDirectionBreak(self, p):
|
||||
"""
|
||||
Code is called if modifySearchDirection fails
|
||||
to find a descent direction.
|
||||
|
||||
The search direction is passed as input and
|
||||
this function must pass back both a new searchDirection,
|
||||
and if the searchDirection break has been caught.
|
||||
|
||||
By default, no additional work is done, and the
|
||||
evalFunction returns a False indicating the break was not caught.
|
||||
|
||||
:param numpy.ndarray p: searchDirection
|
||||
:rtype: numpy.ndarray,bool
|
||||
:return: (xt, breakCaught)
|
||||
"""
|
||||
print 'The linesearch got broken. Boo.'
|
||||
return p, False
|
||||
|
||||
def doEndIteration(self, xt):
|
||||
"""
|
||||
**doEndIteration** is called at the end of each minimize iteration.
|
||||
|
||||
By default, function values and x locations are shuffled to store 1 past iteration in memory.
|
||||
|
||||
self.xc must be updated in this code.
|
||||
|
||||
:param numpy.ndarray xt: tested new iterate that ensures a descent direction.
|
||||
:rtype: None
|
||||
:return: None
|
||||
"""
|
||||
# store old values
|
||||
self.fOld = self.f
|
||||
self.xOld, self.xc = self.xc, xt
|
||||
@@ -123,7 +351,19 @@ class Minimize(object):
|
||||
class GaussNewton(Minimize):
|
||||
name = 'GaussNewton'
|
||||
def findSearchDirection(self):
|
||||
return np.linalg.solve(self.H,-self.g)
|
||||
return Solver(self.H).solve(-self.g)
|
||||
|
||||
|
||||
class InexactGaussNewton(Minimize):
|
||||
name = 'InexactGaussNewton'
|
||||
|
||||
maxIterCG = 10
|
||||
tolCG = 1e-5
|
||||
|
||||
def findSearchDirection(self):
|
||||
# TODO: use BFGS as a preconditioner or gauss sidel of the WtW or solve WtW directly
|
||||
p, info = sp.linalg.cg(self.H, -self.g, tol=self.tolCG, maxiter=self.maxIterCG)
|
||||
return p
|
||||
|
||||
|
||||
class SteepestDescent(Minimize):
|
||||
@@ -133,18 +373,15 @@ class SteepestDescent(Minimize):
|
||||
|
||||
if __name__ == '__main__':
|
||||
from SimPEG.tests import Rosenbrock, checkDerivative
|
||||
import matplotlib.pyplot as plt
|
||||
x0 = np.array([2.6, 3.7])
|
||||
checkDerivative(Rosenbrock, x0, plotIt=False)
|
||||
xOpt = GaussNewton(Rosenbrock, maxIter=20).minimize(x0)
|
||||
|
||||
def listener1(minimize,p):
|
||||
print 'hi: ', p
|
||||
if doPub: pub.subscribe(listener1, 'Minimize.searchDirection')
|
||||
|
||||
xOpt = GaussNewton(maxIter=20,tolF=1e-10,tolX=1e-10,tolG=1e-10).minimize(Rosenbrock,x0)
|
||||
print "xOpt=[%f, %f]" % (xOpt[0], xOpt[1])
|
||||
xOpt = SteepestDescent(Rosenbrock, maxIter=20, maxIterLS=15).minimize(x0)
|
||||
xOpt = SteepestDescent(maxIter=30, maxIterLS=15,tolF=1e-10,tolX=1e-10,tolG=1e-10).minimize(Rosenbrock, x0)
|
||||
print "xOpt=[%f, %f]" % (xOpt[0], xOpt[1])
|
||||
|
||||
def simplePass(x):
|
||||
return np.sin(x), sdiag(np.cos(x))
|
||||
|
||||
def simpleFail(x):
|
||||
return np.sin(x), -sdiag(np.cos(x))
|
||||
|
||||
checkDerivative(simplePass, np.random.randn(5), plotIt=False)
|
||||
checkDerivative(simpleFail, np.random.randn(5), plotIt=False)
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
from Optimize import *
|
||||
from Inversion import *
|
||||
import BetaSchedule
|
||||
|
||||
@@ -2,6 +2,7 @@ import numpy as np
|
||||
from SimPEG.utils import mkvc
|
||||
|
||||
|
||||
|
||||
class BaseMesh(object):
|
||||
"""
|
||||
BaseMesh does all the counting you don't want to do.
|
||||
@@ -216,6 +217,12 @@ class BaseMesh(object):
|
||||
|
||||
:rtype: int
|
||||
:return: nC
|
||||
|
||||
.. plot::
|
||||
|
||||
from SimPEG.mesh import TensorMesh
|
||||
import numpy as np
|
||||
TensorMesh([np.ones(n) for n in [2,3]]).plotGrid(centers=True,showIt=True)
|
||||
"""
|
||||
fget = lambda self: np.prod(self.n)
|
||||
return locals()
|
||||
@@ -270,6 +277,12 @@ class BaseMesh(object):
|
||||
|
||||
:rtype: int
|
||||
:return: nN
|
||||
|
||||
.. plot::
|
||||
|
||||
from SimPEG.mesh import TensorMesh
|
||||
import numpy as np
|
||||
TensorMesh([np.ones(n) for n in [2,3]]).plotGrid(nodes=True,showIt=True)
|
||||
"""
|
||||
fget = lambda self: np.prod(self.n + 1)
|
||||
return locals()
|
||||
@@ -324,6 +337,12 @@ class BaseMesh(object):
|
||||
|
||||
:rtype: numpy.array (dim, )
|
||||
:return: [prod(nEx), prod(nEy), prod(nEz)]
|
||||
|
||||
.. plot::
|
||||
|
||||
from SimPEG.mesh import TensorMesh
|
||||
import numpy as np
|
||||
TensorMesh([np.ones(n) for n in [2,3]]).plotGrid(edges=True,showIt=True)
|
||||
"""
|
||||
fget = lambda self: np.array([np.prod(x) for x in [self.nEx, self.nEy, self.nEz] if not x is None])
|
||||
return locals()
|
||||
@@ -378,6 +397,12 @@ class BaseMesh(object):
|
||||
|
||||
:rtype: numpy.array (dim, )
|
||||
:return: [prod(nFx), prod(nFy), prod(nFz)]
|
||||
|
||||
.. plot::
|
||||
|
||||
from SimPEG.mesh import TensorMesh
|
||||
import numpy as np
|
||||
TensorMesh([np.ones(n) for n in [2,3]]).plotGrid(faces=True,showIt=True)
|
||||
"""
|
||||
fget = lambda self: np.array([np.prod(x) for x in [self.nFx, self.nFy, self.nFz] if not x is None])
|
||||
return locals()
|
||||
|
||||
+12
-12
@@ -5,8 +5,8 @@ from SimPEG.utils import mkvc, ndgrid, sdiag
|
||||
|
||||
class Cyl1DMesh(object):
|
||||
"""
|
||||
Cyl1DMesh is a mesh class for cylindrically symmetric 1D problems
|
||||
"""
|
||||
Cyl1DMesh is a mesh class for cylindrically symmetric 1D problems
|
||||
"""
|
||||
|
||||
_meshType = 'CYL1D'
|
||||
|
||||
@@ -20,7 +20,7 @@ class Cyl1DMesh(object):
|
||||
assert len(h_i.shape) == 1, ("h[%i] must be a 1D numpy array." % i)
|
||||
|
||||
# Ensure h contains 1D vectors
|
||||
self._h = [mkvc(x) for x in h]
|
||||
self._h = [mkvc(x.astype(float)) for x in h]
|
||||
|
||||
if z0 is None:
|
||||
z0 = 0
|
||||
@@ -146,7 +146,7 @@ class Cyl1DMesh(object):
|
||||
|
||||
def vectorCCz():
|
||||
doc = "Cell centered grid vector (1D) in the z direction"
|
||||
fget = lambda self: self.hz.cumsum() - self.hz/2 + self._z0
|
||||
fget = lambda self: self.hz.cumsum() - self.hz/2 + self._z0
|
||||
return locals()
|
||||
vectorCCz = property(**vectorCCz())
|
||||
|
||||
@@ -177,7 +177,7 @@ class Cyl1DMesh(object):
|
||||
self._gridFr = ndgrid([self.vectorNr, self.vectorCCz])
|
||||
return self._gridFr
|
||||
return locals()
|
||||
_gridFr = None
|
||||
_gridFr = None
|
||||
gridFr = property(**gridFr())
|
||||
|
||||
def gridFz():
|
||||
@@ -187,7 +187,7 @@ class Cyl1DMesh(object):
|
||||
self._gridFz = ndgrid([self.vectorCCr, self.vectorNz])
|
||||
return self._gridFz
|
||||
return locals()
|
||||
_gridFz = None
|
||||
_gridFz = None
|
||||
gridFz = property(**gridFz())
|
||||
|
||||
####################################################
|
||||
@@ -350,23 +350,23 @@ class Cyl1DMesh(object):
|
||||
np.all(loc[:,1]<=self.vectorNz.max()), \
|
||||
"Points outside of mesh"
|
||||
|
||||
|
||||
|
||||
if locType=='fz':
|
||||
Q = sp.lil_matrix((loc.shape[0], self.nF), dtype=float)
|
||||
|
||||
for i, iloc in enumerate(loc):
|
||||
# Point is on a z-interface
|
||||
if np.any(np.abs(self.vectorNz-iloc[1])<0.001):
|
||||
if np.any(np.abs(self.vectorNz-iloc[1])<0.001):
|
||||
dFz = self.gridFz-iloc #Distance to z faces
|
||||
dFz[dFz[:,0]>0,:] = np.inf #Looking for next face to the left...
|
||||
indL = np.argmin(np.sum(dFz**2, axis=1)) #Closest one
|
||||
if self.gridFz[indL,0] == self.vectorCCr.max(): #Point in outer half cell (linear extrapolation)
|
||||
zFL = self.gridFz[indL,:]
|
||||
zFLL = self.gridFz[indL-1,:]
|
||||
zFL = self.gridFz[indL,:]
|
||||
zFLL = self.gridFz[indL-1,:]
|
||||
Q[i, indL+self.nFr] = (iloc[0] - zFLL[0])/(zFL[0] - zFLL[0])
|
||||
Q[i, indL+self.nFr-1] = -(iloc[0] - zFL[0])/(zFL[0] - zFLL[0])
|
||||
else:
|
||||
zFL = self.gridFz[indL,:]
|
||||
zFL = self.gridFz[indL,:]
|
||||
zFR = self.gridFz[indL+1,:]
|
||||
Q[i,indL+self.nFr] = (zFR[0] - iloc[0])/(zFR[0] - zFL[0])
|
||||
Q[i,indL+self.nFr+1] = (iloc[0] - zFL[0])/(zFR[0] - zFL[0])
|
||||
@@ -400,7 +400,7 @@ class Cyl1DMesh(object):
|
||||
Q[i, indAL+self.nFr-1] = -(dzB/DZ)*(drL/DR)
|
||||
Q[i, indAL+self.nFr] = (dzB/DZ)*(drLL/DR)
|
||||
else:
|
||||
indBR = indBL+1 # Face below and to the right
|
||||
indBR = indBL+1 # Face below and to the right
|
||||
indAR = indAL + 1 # Face above and to the right
|
||||
zF_BR = self.gridFz[indBR,:]
|
||||
|
||||
|
||||
@@ -161,6 +161,68 @@ class DiffOperators(object):
|
||||
_cellGrad = None
|
||||
cellGrad = property(**cellGrad())
|
||||
|
||||
def cellGradx():
|
||||
doc = "Cell centered Gradient in the x dimension. Has neumann boundary conditions."
|
||||
|
||||
def fget(self):
|
||||
if getattr(self, '_cellGradx', None) is None:
|
||||
BC = ['neumann', 'neumann']
|
||||
n = self.n
|
||||
if(self.dim == 1):
|
||||
G1 = ddxCellGrad(n[0], BC)
|
||||
elif(self.dim == 2):
|
||||
G1 = sp.kron(speye(n[1]), ddxCellGrad(n[0], BC))
|
||||
elif(self.dim == 3):
|
||||
G1 = kron3(speye(n[2]), speye(n[1]), ddxCellGrad(n[0], BC))
|
||||
# Compute areas of cell faces & volumes
|
||||
S = self.r(self.area, 'F','Fx', 'V')
|
||||
V = self.vol
|
||||
self._cellGradx = sdiag(S)*G1*sdiag(1/V)
|
||||
return self._cellGradx
|
||||
return locals()
|
||||
cellGradx = property(**cellGradx())
|
||||
|
||||
|
||||
def cellGrady():
|
||||
doc = "Cell centered Gradient in the x dimension. Has neumann boundary conditions."
|
||||
def fget(self):
|
||||
if self.dim < 2:
|
||||
return None
|
||||
if getattr(self, '_cellGrady', None) is None:
|
||||
BC = ['neumann', 'neumann']
|
||||
n = self.n
|
||||
if(self.dim == 2):
|
||||
G2 = sp.kron(ddxCellGrad(n[1], BC), speye(n[0]))
|
||||
elif(self.dim == 3):
|
||||
G2 = kron3(speye(n[2]), ddxCellGrad(n[1], BC), speye(n[0]))
|
||||
# Compute areas of cell faces & volumes
|
||||
S = self.r(self.area, 'F','Fy', 'V')
|
||||
V = self.vol
|
||||
self._cellGrady = sdiag(S)*G2*sdiag(1/V)
|
||||
return self._cellGrady
|
||||
return locals()
|
||||
cellGrady = property(**cellGrady())
|
||||
|
||||
|
||||
|
||||
def cellGradz():
|
||||
doc = "Cell centered Gradient in the x dimension. Has neumann boundary conditions."
|
||||
def fget(self):
|
||||
if self.dim < 3:
|
||||
return None
|
||||
if getattr(self, '_cellGradz', None) is None:
|
||||
BC = ['neumann', 'neumann']
|
||||
n = self.n
|
||||
G3 = kron3(ddxCellGrad(n[2], BC), speye(n[1]), speye(n[0]))
|
||||
# Compute areas of cell faces & volumes
|
||||
S = self.r(self.area, 'F','Fz', 'V')
|
||||
V = self.vol
|
||||
self._cellGradz = sdiag(S)*G3*sdiag(1/V)
|
||||
return self._cellGradz
|
||||
return locals()
|
||||
cellGradz = property(**cellGradz())
|
||||
|
||||
|
||||
def edgeCurl():
|
||||
doc = "Construct the 3D curl operator."
|
||||
|
||||
|
||||
@@ -81,9 +81,9 @@ class InnerProducts(object):
|
||||
def getFaceInnerProduct(self, mu=None, returnP=False):
|
||||
"""Wrapper function,
|
||||
|
||||
:py:func:`SimPEG.InnerProducts.getEdgeInnerProduct`
|
||||
:py:func:`SimPEG.mesh.InnerProducts.InnerProducts.getEdgeInnerProduct`
|
||||
|
||||
:py:func:`SimPEG.InnerProducts.getEdgeInnerProduct2D`
|
||||
:py:func:`SimPEG.mesh.InnerProducts.InnerProducts.getEdgeInnerProduct2D`
|
||||
"""
|
||||
if self.dim == 2:
|
||||
return getFaceInnerProduct2D(self, mu, returnP)
|
||||
@@ -93,9 +93,9 @@ class InnerProducts(object):
|
||||
def getEdgeInnerProduct(self, sigma=None, returnP=False):
|
||||
"""Wrapper function,
|
||||
|
||||
:py:func:`SimPEG.InnerProducts.getFaceInnerProduct`
|
||||
:py:func:`SimPEG.mesh.InnerProducts.InnerProducts.getFaceInnerProduct`
|
||||
|
||||
:py:func:`SimPEG.InnerProducts.getFaceInnerProduct2D`
|
||||
:py:func:`SimPEG.mesh.InnerProducts.InnerProducts.getFaceInnerProduct2D`
|
||||
"""
|
||||
if self.dim == 2:
|
||||
return getEdgeInnerProduct2D(self, sigma, returnP)
|
||||
|
||||
@@ -38,7 +38,7 @@ class LogicallyOrthogonalMesh(BaseMesh, DiffOperators, InnerProducts, LomView):
|
||||
# Save nodes to private variable _gridN as vectors
|
||||
self._gridN = np.ones((nodes[0].size, self.dim))
|
||||
for i, node_i in enumerate(nodes):
|
||||
self._gridN[:, i] = mkvc(node_i)
|
||||
self._gridN[:, i] = mkvc(node_i.astype(float))
|
||||
|
||||
def gridCC():
|
||||
doc = "Cell-centered grid."
|
||||
|
||||
+104
-10
@@ -1,9 +1,10 @@
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
from BaseMesh import BaseMesh
|
||||
from TensorView import TensorView
|
||||
from DiffOperators import DiffOperators
|
||||
from InnerProducts import InnerProducts
|
||||
from SimPEG.utils import ndgrid, mkvc
|
||||
from SimPEG.utils import ndgrid, mkvc, spzeros, interpmat
|
||||
|
||||
|
||||
class TensorMesh(BaseMesh, TensorView, DiffOperators, InnerProducts):
|
||||
@@ -38,7 +39,7 @@ class TensorMesh(BaseMesh, TensorView, DiffOperators, InnerProducts):
|
||||
assert len(h_i.shape) == 1, ("h[%i] must be a 1D numpy array." % i)
|
||||
|
||||
# Ensure h contains 1D vectors
|
||||
self._h = [mkvc(x) for x in h]
|
||||
self._h = [mkvc(x.astype(float)) for x in h]
|
||||
|
||||
def __str__(self):
|
||||
outStr = ' ---- {0:d}-D TensorMesh ---- '.format(self.dim)
|
||||
@@ -156,7 +157,7 @@ class TensorMesh(BaseMesh, TensorView, DiffOperators, InnerProducts):
|
||||
|
||||
def fget(self):
|
||||
if self._gridCC is None:
|
||||
self._gridCC = ndgrid([x for x in [self.vectorCCx, self.vectorCCy, self.vectorCCz] if not x is None])
|
||||
self._gridCC = ndgrid(self.getTensor('CC'))
|
||||
return self._gridCC
|
||||
return locals()
|
||||
_gridCC = None # Store grid by default
|
||||
@@ -167,7 +168,7 @@ class TensorMesh(BaseMesh, TensorView, DiffOperators, InnerProducts):
|
||||
|
||||
def fget(self):
|
||||
if self._gridN is None:
|
||||
self._gridN = ndgrid([x for x in [self.vectorNx, self.vectorNy, self.vectorNz] if not x is None])
|
||||
self._gridN = ndgrid(self.getTensor('N'))
|
||||
return self._gridN
|
||||
return locals()
|
||||
_gridN = None # Store grid by default
|
||||
@@ -178,7 +179,7 @@ class TensorMesh(BaseMesh, TensorView, DiffOperators, InnerProducts):
|
||||
|
||||
def fget(self):
|
||||
if self._gridFx is None:
|
||||
self._gridFx = ndgrid([x for x in [self.vectorNx, self.vectorCCy, self.vectorCCz] if not x is None])
|
||||
self._gridFx = ndgrid(self.getTensor('Fx'))
|
||||
return self._gridFx
|
||||
return locals()
|
||||
_gridFx = None # Store grid by default
|
||||
@@ -189,7 +190,7 @@ class TensorMesh(BaseMesh, TensorView, DiffOperators, InnerProducts):
|
||||
|
||||
def fget(self):
|
||||
if self._gridFy is None and self.dim > 1:
|
||||
self._gridFy = ndgrid([x for x in [self.vectorCCx, self.vectorNy, self.vectorCCz] if not x is None])
|
||||
self._gridFy = ndgrid(self.getTensor('Fy'))
|
||||
return self._gridFy
|
||||
return locals()
|
||||
_gridFy = None # Store grid by default
|
||||
@@ -200,7 +201,7 @@ class TensorMesh(BaseMesh, TensorView, DiffOperators, InnerProducts):
|
||||
|
||||
def fget(self):
|
||||
if self._gridFz is None and self.dim > 2:
|
||||
self._gridFz = ndgrid([x for x in [self.vectorCCx, self.vectorCCy, self.vectorNz] if not x is None])
|
||||
self._gridFz = ndgrid(self.getTensor('Fz'))
|
||||
return self._gridFz
|
||||
return locals()
|
||||
_gridFz = None # Store grid by default
|
||||
@@ -211,7 +212,7 @@ class TensorMesh(BaseMesh, TensorView, DiffOperators, InnerProducts):
|
||||
|
||||
def fget(self):
|
||||
if self._gridEx is None:
|
||||
self._gridEx = ndgrid([x for x in [self.vectorCCx, self.vectorNy, self.vectorNz] if not x is None])
|
||||
self._gridEx = ndgrid(self.getTensor('Ex'))
|
||||
return self._gridEx
|
||||
return locals()
|
||||
_gridEx = None # Store grid by default
|
||||
@@ -222,7 +223,7 @@ class TensorMesh(BaseMesh, TensorView, DiffOperators, InnerProducts):
|
||||
|
||||
def fget(self):
|
||||
if self._gridEy is None and self.dim > 1:
|
||||
self._gridEy = ndgrid([x for x in [self.vectorNx, self.vectorCCy, self.vectorNz] if not x is None])
|
||||
self._gridEy = ndgrid(self.getTensor('Ey'))
|
||||
return self._gridEy
|
||||
return locals()
|
||||
_gridEy = None # Store grid by default
|
||||
@@ -233,7 +234,7 @@ class TensorMesh(BaseMesh, TensorView, DiffOperators, InnerProducts):
|
||||
|
||||
def fget(self):
|
||||
if self._gridEz is None and self.dim > 2:
|
||||
self._gridEz = ndgrid([x for x in [self.vectorNx, self.vectorNy, self.vectorCCz] if not x is None])
|
||||
self._gridEz = ndgrid(self.getTensor('Ez'))
|
||||
return self._gridEz
|
||||
return locals()
|
||||
_gridEz = None # Store grid by default
|
||||
@@ -312,6 +313,98 @@ class TensorMesh(BaseMesh, TensorView, DiffOperators, InnerProducts):
|
||||
_edge = None
|
||||
edge = property(**edge())
|
||||
|
||||
# --------------- Methods ---------------------
|
||||
|
||||
def getTensor(self, locType):
|
||||
""" Returns a tensor list.
|
||||
|
||||
:param str locType: What tensor (see below)
|
||||
:rtype: list
|
||||
:return: list of the tensors that make up the mesh.
|
||||
|
||||
locType can be::
|
||||
|
||||
'Ex' -> x-component of field defined on edges
|
||||
'Ey' -> y-component of field defined on edges
|
||||
'Ez' -> z-component of field defined on edges
|
||||
'Fx' -> x-component of field defined on faces
|
||||
'Fy' -> y-component of field defined on faces
|
||||
'Fz' -> z-component of field defined on faces
|
||||
'N' -> scalar field defined on nodes
|
||||
'CC' -> scalar field defined on cell centers
|
||||
"""
|
||||
|
||||
if locType is 'Fx':
|
||||
ten = [self.vectorNx , self.vectorCCy, self.vectorCCz]
|
||||
elif locType is 'Fy':
|
||||
ten = [self.vectorCCx, self.vectorNy , self.vectorCCz]
|
||||
elif locType is 'Fz':
|
||||
ten = [self.vectorCCx, self.vectorCCy, self.vectorNz ]
|
||||
elif locType is 'Ex':
|
||||
ten = [self.vectorCCx, self.vectorNy , self.vectorNz ]
|
||||
elif locType is 'Ey':
|
||||
ten = [self.vectorNx , self.vectorCCy, self.vectorNz ]
|
||||
elif locType is 'Ez':
|
||||
ten = [self.vectorNx , self.vectorNy , self.vectorCCz]
|
||||
elif locType is 'CC':
|
||||
ten = [self.vectorCCx, self.vectorCCy, self.vectorCCz]
|
||||
elif locType is 'N':
|
||||
ten = [self.vectorNx , self.vectorNy , self.vectorNz ]
|
||||
|
||||
return [t for t in ten if t is not None]
|
||||
|
||||
|
||||
def isInside(self, pts):
|
||||
"""
|
||||
Determines if a set of points are inside a mesh.
|
||||
|
||||
:param numpy.ndarray pts: Location of points to test
|
||||
:rtype numpy.ndarray
|
||||
:return inside, numpy array of booleans
|
||||
"""
|
||||
|
||||
pts = np.atleast_2d(pts)
|
||||
inside = (pts[:,0] >= self.vectorNx.min()) & (pts[:,0] <= self.vectorNx.max())
|
||||
if self.dim > 1:
|
||||
inside = inside & ((pts[:,1] >= self.vectorNy.min()) & (pts[:,1] <= self.vectorNy.max()))
|
||||
if self.dim > 2:
|
||||
inside = inside & ((pts[:,2] >= self.vectorNz.min()) & (pts[:,2] <= self.vectorNz.max()))
|
||||
return inside
|
||||
|
||||
def getInterpolationMat(self, loc, locType):
|
||||
""" Produces interpolation matrix
|
||||
|
||||
:param numpy.ndarray loc: Location of points to interpolate to
|
||||
:param str locType: What to interpolate (see below)
|
||||
:rtype: scipy.sparse.csr.csr_matrix
|
||||
:return: M, the interpolation matrix
|
||||
|
||||
locType can be::
|
||||
|
||||
'Ex' -> x-component of field defined on edges
|
||||
'Ey' -> y-component of field defined on edges
|
||||
'Ez' -> z-component of field defined on edges
|
||||
'Fx' -> x-component of field defined on faces
|
||||
'Fy' -> y-component of field defined on faces
|
||||
'Fz' -> z-component of field defined on faces
|
||||
'N' -> scalar field defined on nodes
|
||||
'CC' -> scalar field defined on cell centers
|
||||
"""
|
||||
|
||||
loc = np.atleast_2d(loc)
|
||||
assert np.all(self.isInside(loc)), "Points outside of mesh"
|
||||
|
||||
ind = 0 if 'x' in locType else 1 if 'y' in locType else 2 if 'z' in locType else -1
|
||||
if locType in ['Fx','Fy','Fz','Ex','Ey','Ez'] and self.dim >= ind:
|
||||
nF_nE = self.nF if 'F' in locType else self.nE
|
||||
components = [spzeros(loc.shape[0], n) for n in nF_nE]
|
||||
components[ind] = interpmat(loc, *self.getTensor(locType))
|
||||
Q = sp.hstack(components)
|
||||
elif locType in ['CC', 'N']:
|
||||
Q = interpmat(loc, *self.getTensor(locType))
|
||||
else:
|
||||
raise NotImplementedError('getInterpolationMat: locType=='+locType+' and mesh.dim=='+str(self.dim))
|
||||
return Q
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('Welcome to tensor mesh!')
|
||||
@@ -327,5 +420,6 @@ if __name__ == '__main__':
|
||||
h = h[:testDim]
|
||||
|
||||
M = TensorMesh(h)
|
||||
print M
|
||||
|
||||
xn = M.plotGrid()
|
||||
|
||||
@@ -267,6 +267,9 @@ class TensorView(object):
|
||||
if faces:
|
||||
ax.plot(xs1[:, 0], xs1[:, 1], 'g>')
|
||||
ax.plot(xs2[:, 0], xs2[:, 1], 'g^')
|
||||
if edges:
|
||||
ax.plot(self.gridEx[:, 0], self.gridEx[:, 1], 'c>')
|
||||
ax.plot(self.gridEy[:, 0], self.gridEy[:, 1], 'c^')
|
||||
|
||||
# Plot the grid lines
|
||||
if lines:
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
from SimPEG.utils import sdiag
|
||||
import numpy as np
|
||||
|
||||
class Regularization(object):
|
||||
"""docstring for Regularization"""
|
||||
|
||||
@property
|
||||
def mref(self):
|
||||
if getattr(self, '_mref', None) is None:
|
||||
self._mref = np.zeros(self.mesh.nC);
|
||||
return self._mref
|
||||
@mref.setter
|
||||
def mref(self, value):
|
||||
self._mref = value
|
||||
|
||||
@property
|
||||
def Ws(self):
|
||||
if getattr(self,'_Ws', None) is None:
|
||||
self._Ws = sdiag(self.mesh.vol)
|
||||
return self._Ws
|
||||
|
||||
@property
|
||||
def Wx(self):
|
||||
if getattr(self, '_Wx', None) is None:
|
||||
a = self.mesh.r(self.mesh.area,'F','Fx','V')
|
||||
self._Wx = sdiag(a)*self.mesh.cellGradx
|
||||
return self._Wx
|
||||
|
||||
@property
|
||||
def Wy(self):
|
||||
if getattr(self, '_Wy', None) is None:
|
||||
a = self.mesh.r(self.mesh.area,'F','Fy','V')
|
||||
self._Wy = sdiag(a)*self.mesh.cellGrady
|
||||
return self._Wy
|
||||
|
||||
@property
|
||||
def Wz(self):
|
||||
if getattr(self, '_Wz', None) is None:
|
||||
a = self.mesh.r(self.mesh.area,'F','Fz','V')
|
||||
self._Wz = sdiag(a)*self.mesh.cellGradz
|
||||
return self._Wz
|
||||
|
||||
|
||||
|
||||
def __init__(self, mesh):
|
||||
self.mesh = mesh
|
||||
self._Wx = None
|
||||
self._Wy = None
|
||||
self._Wz = None
|
||||
self.alpha_s = 1e-6
|
||||
self.alpha_x = 1
|
||||
self.alpha_y = 1
|
||||
self.alpha_z = 1
|
||||
|
||||
def pnorm(self, r):
|
||||
return 0.5*r.dot(r)
|
||||
|
||||
def modelObj(self, m):
|
||||
mresid = m - self.mref
|
||||
|
||||
mobj = self.alpha_s * self.pnorm( self.Ws * mresid )
|
||||
|
||||
mobj += self.alpha_x * self.pnorm( self.Wx * mresid )
|
||||
|
||||
if self.mesh.dim > 1:
|
||||
mobj += self.alpha_y * self.pnorm( self.Wy * mresid )
|
||||
if self.mesh.dim > 2:
|
||||
mobj += self.alpha_z * self.pnorm( self.Wz * mresid )
|
||||
|
||||
return mobj
|
||||
|
||||
def modelObjDeriv(self, m):
|
||||
"""
|
||||
|
||||
In 1D:
|
||||
|
||||
.. math::
|
||||
|
||||
m_{\\text{obj}} = {1 \over 2}\\alpha_s \left\| W_s (m- m_{\\text{ref}})\\right\|^2_2
|
||||
+ {1 \over 2}\\alpha_x \left\| W_x (m- m_{\\text{ref}})\\right\|^2_2
|
||||
|
||||
\\frac{ \partial m_{\\text{obj}} }{\partial m} =
|
||||
\\alpha_s W_s^{\\top} W_s (m - m_{\\text{ref}}) +
|
||||
\\alpha_x W_x^{\\top} W_x (m - m_{\\text{ref}})
|
||||
|
||||
|
||||
\\frac{ \partial^2 m_{\\text{obj}} }{\partial m^2} =
|
||||
\\alpha_s W_s^{\\top} W_s +
|
||||
\\alpha_x W_x^{\\top} W_x
|
||||
|
||||
"""
|
||||
|
||||
mresid = m - self.mref
|
||||
|
||||
mobjDeriv = self.alpha_s * self.Ws.T * ( self.Ws * mresid)
|
||||
|
||||
mobjDeriv = mobjDeriv + self.alpha_x * self.Wx.T * ( self.Wx * mresid)
|
||||
|
||||
if self.mesh.dim > 1:
|
||||
mobjDeriv = mobjDeriv + self.alpha_y * self.Wy.T * ( self.Wy * mresid)
|
||||
if self.mesh.dim > 2:
|
||||
mobjDeriv = mobjDeriv + self.alpha_z * self.Wz.T * ( self.Wz * mresid)
|
||||
|
||||
return mobjDeriv
|
||||
|
||||
|
||||
def modelObj2Deriv(self, m):
|
||||
mresid = m - self.mref
|
||||
|
||||
mobj2Deriv = self.alpha_s * self.Ws.T * self.Ws
|
||||
|
||||
mobj2Deriv = mobj2Deriv + self.alpha_x * self.Wx.T * self.Wx
|
||||
|
||||
if self.mesh.dim > 1:
|
||||
mobj2Deriv = mobj2Deriv + self.alpha_y * self.Wy.T * self.Wy
|
||||
if self.mesh.dim > 2:
|
||||
mobj2Deriv = mobj2Deriv + self.alpha_z * self.Wz.T * self.Wz
|
||||
|
||||
return mobj2Deriv
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from Regularization import Regularization
|
||||
@@ -0,0 +1,824 @@
|
||||
"""
|
||||
A TestRunner for use with the Python unit testing framework. It
|
||||
generates a HTML report to show the result at a glance.
|
||||
|
||||
The simplest way to use this is to invoke its main method. E.g.
|
||||
|
||||
import unittest
|
||||
import HTMLTestRunner
|
||||
|
||||
... define your tests ...
|
||||
|
||||
if __name__ == '__main__':
|
||||
HTMLTestRunner.main()
|
||||
|
||||
|
||||
For more customization options, instantiates a HTMLTestRunner object.
|
||||
HTMLTestRunner is a counterpart to unittest's TextTestRunner. E.g.
|
||||
|
||||
# output to a file
|
||||
fp = file('my_report.html', 'wb')
|
||||
runner = HTMLTestRunner.HTMLTestRunner(
|
||||
stream=fp,
|
||||
title='My unit test',
|
||||
description='This demonstrates the report output by HTMLTestRunner.'
|
||||
)
|
||||
|
||||
# Use an external stylesheet.
|
||||
# See the Template_mixin class for more customizable options
|
||||
runner.STYLESHEET_TMPL = '<link rel="stylesheet" href="my_stylesheet.css" type="text/css">'
|
||||
|
||||
# run the test
|
||||
runner.run(my_test_suite)
|
||||
|
||||
|
||||
------------------------------------------------------------------------
|
||||
Copyright (c) 2004-2007, Wai Yip Tung
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice,
|
||||
this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
* Neither the name Wai Yip Tung nor the names of its contributors may be
|
||||
used to endorse or promote products derived from this software without
|
||||
specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
|
||||
IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
|
||||
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
|
||||
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
|
||||
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
"""
|
||||
|
||||
# URL: http://tungwaiyip.info/software/HTMLTestRunner.html
|
||||
|
||||
__author__ = "Wai Yip Tung"
|
||||
__version__ = "0.8.2"
|
||||
|
||||
|
||||
"""
|
||||
Change History
|
||||
|
||||
Version 0.8.2
|
||||
* Show output inline instead of popup window (Viorel Lupu).
|
||||
|
||||
Version in 0.8.1
|
||||
* Validated XHTML (Wolfgang Borgert).
|
||||
* Added description of test classes and test cases.
|
||||
|
||||
Version in 0.8.0
|
||||
* Define Template_mixin class for customization.
|
||||
* Workaround a IE 6 bug that it does not treat <script> block as CDATA.
|
||||
|
||||
Version in 0.7.1
|
||||
* Back port to Python 2.3 (Frank Horowitz).
|
||||
* Fix missing scroll bars in detail log (Podi).
|
||||
"""
|
||||
|
||||
# TODO: color stderr
|
||||
# TODO: simplify javascript using ,ore than 1 class in the class attribute?
|
||||
|
||||
import datetime
|
||||
import StringIO
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
from xml.sax import saxutils
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# The redirectors below are used to capture output during testing. Output
|
||||
# sent to sys.stdout and sys.stderr are automatically captured. However
|
||||
# in some cases sys.stdout is already cached before HTMLTestRunner is
|
||||
# invoked (e.g. calling logging.basicConfig). In order to capture those
|
||||
# output, use the redirectors for the cached stream.
|
||||
#
|
||||
# e.g.
|
||||
# >>> logging.basicConfig(stream=HTMLTestRunner.stdout_redirector)
|
||||
# >>>
|
||||
|
||||
class OutputRedirector(object):
|
||||
""" Wrapper to redirect stdout or stderr """
|
||||
def __init__(self, fp):
|
||||
self.fp = fp
|
||||
|
||||
def write(self, s):
|
||||
self.fp.write(s)
|
||||
|
||||
def writelines(self, lines):
|
||||
self.fp.writelines(lines)
|
||||
|
||||
def flush(self):
|
||||
self.fp.flush()
|
||||
|
||||
stdout_redirector = OutputRedirector(sys.stdout)
|
||||
stderr_redirector = OutputRedirector(sys.stderr)
|
||||
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Template
|
||||
|
||||
class Template_mixin(object):
|
||||
"""
|
||||
Define a HTML template for report customerization and generation.
|
||||
|
||||
Overall structure of an HTML report
|
||||
|
||||
HTML
|
||||
+------------------------+
|
||||
|<html> |
|
||||
| <head> |
|
||||
| |
|
||||
| STYLESHEET |
|
||||
| +----------------+ |
|
||||
| | | |
|
||||
| +----------------+ |
|
||||
| |
|
||||
| </head> |
|
||||
| |
|
||||
| <body> |
|
||||
| |
|
||||
| HEADING |
|
||||
| +----------------+ |
|
||||
| | | |
|
||||
| +----------------+ |
|
||||
| |
|
||||
| REPORT |
|
||||
| +----------------+ |
|
||||
| | | |
|
||||
| +----------------+ |
|
||||
| |
|
||||
| ENDING |
|
||||
| +----------------+ |
|
||||
| | | |
|
||||
| +----------------+ |
|
||||
| |
|
||||
| </body> |
|
||||
|</html> |
|
||||
+------------------------+
|
||||
"""
|
||||
|
||||
STATUS = {
|
||||
0: 'pass',
|
||||
1: 'fail',
|
||||
2: 'error',
|
||||
}
|
||||
|
||||
DEFAULT_TITLE = 'Unit Test Report'
|
||||
DEFAULT_DESCRIPTION = ''
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# HTML Template
|
||||
|
||||
HTML_TMPL = r"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head>
|
||||
<title>%(title)s</title>
|
||||
<meta name="generator" content="%(generator)s"/>
|
||||
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8"/>
|
||||
%(stylesheet)s
|
||||
</head>
|
||||
<body>
|
||||
<script language="javascript" type="text/javascript"><!--
|
||||
output_list = Array();
|
||||
|
||||
/* level - 0:Summary; 1:Failed; 2:All */
|
||||
function showCase(level) {
|
||||
trs = document.getElementsByTagName("tr");
|
||||
for (var i = 0; i < trs.length; i++) {
|
||||
tr = trs[i];
|
||||
id = tr.id;
|
||||
if (id.substr(0,2) == 'ft') {
|
||||
if (level < 1) {
|
||||
tr.className = 'hiddenRow';
|
||||
}
|
||||
else {
|
||||
tr.className = '';
|
||||
}
|
||||
}
|
||||
if (id.substr(0,2) == 'pt') {
|
||||
if (level > 1) {
|
||||
tr.className = '';
|
||||
}
|
||||
else {
|
||||
tr.className = 'hiddenRow';
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function showClassDetail(cid, count) {
|
||||
var id_list = Array(count);
|
||||
var toHide = 1;
|
||||
for (var i = 0; i < count; i++) {
|
||||
tid0 = 't' + cid.substr(1) + '.' + (i+1);
|
||||
tid = 'f' + tid0;
|
||||
tr = document.getElementById(tid);
|
||||
if (!tr) {
|
||||
tid = 'p' + tid0;
|
||||
tr = document.getElementById(tid);
|
||||
}
|
||||
id_list[i] = tid;
|
||||
if (tr.className) {
|
||||
toHide = 0;
|
||||
}
|
||||
}
|
||||
for (var i = 0; i < count; i++) {
|
||||
tid = id_list[i];
|
||||
if (toHide) {
|
||||
document.getElementById('div_'+tid).style.display = 'none'
|
||||
document.getElementById(tid).className = 'hiddenRow';
|
||||
}
|
||||
else {
|
||||
document.getElementById(tid).className = '';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function showTestDetail(div_id){
|
||||
var details_div = document.getElementById(div_id)
|
||||
var displayState = details_div.style.display
|
||||
// alert(displayState)
|
||||
if (displayState != 'block' ) {
|
||||
displayState = 'block'
|
||||
details_div.style.display = 'block'
|
||||
}
|
||||
else {
|
||||
details_div.style.display = 'none'
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function html_escape(s) {
|
||||
s = s.replace(/&/g,'&');
|
||||
s = s.replace(/</g,'<');
|
||||
s = s.replace(/>/g,'>');
|
||||
return s;
|
||||
}
|
||||
|
||||
/* obsoleted by detail in <div>
|
||||
function showOutput(id, name) {
|
||||
var w = window.open("", //url
|
||||
name,
|
||||
"resizable,scrollbars,status,width=800,height=450");
|
||||
d = w.document;
|
||||
d.write("<pre>");
|
||||
d.write(html_escape(output_list[id]));
|
||||
d.write("\n");
|
||||
d.write("<a href='javascript:window.close()'>close</a>\n");
|
||||
d.write("</pre>\n");
|
||||
d.close();
|
||||
}
|
||||
*/
|
||||
--></script>
|
||||
|
||||
%(heading)s
|
||||
%(report)s
|
||||
%(ending)s
|
||||
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
# variables: (title, generator, stylesheet, heading, report, ending)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# Stylesheet
|
||||
#
|
||||
# alternatively use a <link> for external style sheet, e.g.
|
||||
# <link rel="stylesheet" href="$url" type="text/css">
|
||||
|
||||
STYLESHEET_TMPL = """
|
||||
<style type="text/css" media="screen">
|
||||
body { font-family: verdana, arial, helvetica, sans-serif; font-size: 80%; }
|
||||
table { font-size: 100%; }
|
||||
pre { }
|
||||
|
||||
/* -- heading ---------------------------------------------------------------------- */
|
||||
h1 {
|
||||
font-size: 16pt;
|
||||
color: gray;
|
||||
}
|
||||
.heading {
|
||||
margin-top: 0ex;
|
||||
margin-bottom: 1ex;
|
||||
}
|
||||
|
||||
.heading .attribute {
|
||||
margin-top: 1ex;
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.heading .description {
|
||||
margin-top: 4ex;
|
||||
margin-bottom: 6ex;
|
||||
}
|
||||
|
||||
/* -- css div popup ------------------------------------------------------------------------ */
|
||||
a.popup_link {
|
||||
}
|
||||
|
||||
a.popup_link:hover {
|
||||
color: red;
|
||||
}
|
||||
|
||||
.popup_window {
|
||||
display: none;
|
||||
position: relative;
|
||||
left: 0px;
|
||||
top: 0px;
|
||||
/*border: solid #627173 1px; */
|
||||
padding: 10px;
|
||||
background-color: #E6E6D6;
|
||||
font-family: "Lucida Console", "Courier New", Courier, monospace;
|
||||
text-align: left;
|
||||
font-size: 8pt;
|
||||
width: 500px;
|
||||
}
|
||||
|
||||
}
|
||||
/* -- report ------------------------------------------------------------------------ */
|
||||
#show_detail_line {
|
||||
margin-top: 3ex;
|
||||
margin-bottom: 1ex;
|
||||
}
|
||||
#result_table {
|
||||
width: 80%;
|
||||
border-collapse: collapse;
|
||||
border: 1px solid #777;
|
||||
}
|
||||
#header_row {
|
||||
font-weight: bold;
|
||||
color: white;
|
||||
background-color: #777;
|
||||
}
|
||||
#result_table td {
|
||||
border: 1px solid #777;
|
||||
padding: 2px;
|
||||
}
|
||||
#total_row { font-weight: bold; }
|
||||
.passClass { background-color: #6c6; }
|
||||
.failClass { background-color: #c60; }
|
||||
.errorClass { background-color: #c00; }
|
||||
.passCase { color: #6c6; }
|
||||
.failCase { color: #c60; font-weight: bold; }
|
||||
.errorCase { color: #c00; font-weight: bold; }
|
||||
.hiddenRow { display: none; }
|
||||
.testcase { margin-left: 2em; }
|
||||
|
||||
|
||||
/* -- ending ---------------------------------------------------------------------- */
|
||||
#ending {
|
||||
}
|
||||
|
||||
</style>
|
||||
"""
|
||||
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# Heading
|
||||
#
|
||||
|
||||
HEADING_TMPL = """<div class='heading'>
|
||||
<h1>%(title)s</h1>
|
||||
%(parameters)s
|
||||
<p class='description'>%(description)s</p>
|
||||
</div>
|
||||
|
||||
""" # variables: (title, parameters, description)
|
||||
|
||||
HEADING_ATTRIBUTE_TMPL = """<p class='attribute'><strong>%(name)s:</strong> %(value)s</p>
|
||||
""" # variables: (name, value)
|
||||
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# Report
|
||||
#
|
||||
|
||||
REPORT_TMPL = """
|
||||
<p id='show_detail_line'>Show
|
||||
<a href='javascript:showCase(0)'>Summary</a>
|
||||
<a href='javascript:showCase(1)'>Failed</a>
|
||||
<a href='javascript:showCase(2)'>All</a>
|
||||
</p>
|
||||
<table id='result_table'>
|
||||
<colgroup>
|
||||
<col align='left' />
|
||||
<col align='right' />
|
||||
<col align='right' />
|
||||
<col align='right' />
|
||||
<col align='right' />
|
||||
<col align='right' />
|
||||
</colgroup>
|
||||
<tr id='header_row'>
|
||||
<td>Test Group/Test case</td>
|
||||
<td>Count</td>
|
||||
<td>Pass</td>
|
||||
<td>Fail</td>
|
||||
<td>Error</td>
|
||||
<td>View</td>
|
||||
</tr>
|
||||
%(test_list)s
|
||||
<tr id='total_row'>
|
||||
<td>Total</td>
|
||||
<td>%(count)s</td>
|
||||
<td>%(Pass)s</td>
|
||||
<td>%(fail)s</td>
|
||||
<td>%(error)s</td>
|
||||
<td> </td>
|
||||
</tr>
|
||||
</table>
|
||||
""" # variables: (test_list, count, Pass, fail, error)
|
||||
|
||||
REPORT_CLASS_TMPL = r"""
|
||||
<tr class='%(style)s'>
|
||||
<td>%(desc)s</td>
|
||||
<td>%(count)s</td>
|
||||
<td>%(Pass)s</td>
|
||||
<td>%(fail)s</td>
|
||||
<td>%(error)s</td>
|
||||
<td><a href="javascript:showClassDetail('%(cid)s',%(count)s)">Detail</a></td>
|
||||
</tr>
|
||||
""" # variables: (style, desc, count, Pass, fail, error, cid)
|
||||
|
||||
|
||||
REPORT_TEST_WITH_OUTPUT_TMPL = r"""
|
||||
<tr id='%(tid)s' class='%(Class)s'>
|
||||
<td class='%(style)s'><div class='testcase'>%(desc)s</div></td>
|
||||
<td colspan='5' align='center'>
|
||||
|
||||
<!--css div popup start-->
|
||||
<a class="popup_link" onfocus='this.blur();' href="javascript:showTestDetail('div_%(tid)s')" >
|
||||
%(status)s</a>
|
||||
|
||||
<div id='div_%(tid)s' class="popup_window">
|
||||
<div style='text-align: right; color:red;cursor:pointer'>
|
||||
<a onfocus='this.blur();' onclick="document.getElementById('div_%(tid)s').style.display = 'none' " >
|
||||
[x]</a>
|
||||
</div>
|
||||
<pre>
|
||||
%(script)s
|
||||
</pre>
|
||||
</div>
|
||||
<!--css div popup end-->
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
""" # variables: (tid, Class, style, desc, status)
|
||||
|
||||
|
||||
REPORT_TEST_NO_OUTPUT_TMPL = r"""
|
||||
<tr id='%(tid)s' class='%(Class)s'>
|
||||
<td class='%(style)s'><div class='testcase'>%(desc)s</div></td>
|
||||
<td colspan='5' align='center'>%(status)s</td>
|
||||
</tr>
|
||||
""" # variables: (tid, Class, style, desc, status)
|
||||
|
||||
|
||||
REPORT_TEST_OUTPUT_TMPL = r"""
|
||||
%(id)s: %(output)s
|
||||
""" # variables: (id, output)
|
||||
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# ENDING
|
||||
#
|
||||
|
||||
ENDING_TMPL = """<div id='ending'> </div>"""
|
||||
|
||||
# -------------------- The end of the Template class -------------------
|
||||
|
||||
|
||||
TestResult = unittest.TestResult
|
||||
|
||||
class _TestResult(TestResult):
|
||||
# note: _TestResult is a pure representation of results.
|
||||
# It lacks the output and reporting ability compares to unittest._TextTestResult.
|
||||
|
||||
def __init__(self, verbosity=1):
|
||||
TestResult.__init__(self)
|
||||
self.stdout0 = None
|
||||
self.stderr0 = None
|
||||
self.success_count = 0
|
||||
self.failure_count = 0
|
||||
self.error_count = 0
|
||||
self.verbosity = verbosity
|
||||
|
||||
# result is a list of result in 4 tuple
|
||||
# (
|
||||
# result code (0: success; 1: fail; 2: error),
|
||||
# TestCase object,
|
||||
# Test output (byte string),
|
||||
# stack trace,
|
||||
# )
|
||||
self.result = []
|
||||
|
||||
|
||||
def startTest(self, test):
|
||||
TestResult.startTest(self, test)
|
||||
# just one buffer for both stdout and stderr
|
||||
self.outputBuffer = StringIO.StringIO()
|
||||
stdout_redirector.fp = self.outputBuffer
|
||||
stderr_redirector.fp = self.outputBuffer
|
||||
self.stdout0 = sys.stdout
|
||||
self.stderr0 = sys.stderr
|
||||
sys.stdout = stdout_redirector
|
||||
sys.stderr = stderr_redirector
|
||||
|
||||
|
||||
def complete_output(self):
|
||||
"""
|
||||
Disconnect output redirection and return buffer.
|
||||
Safe to call multiple times.
|
||||
"""
|
||||
if self.stdout0:
|
||||
sys.stdout = self.stdout0
|
||||
sys.stderr = self.stderr0
|
||||
self.stdout0 = None
|
||||
self.stderr0 = None
|
||||
return self.outputBuffer.getvalue()
|
||||
|
||||
|
||||
def stopTest(self, test):
|
||||
# Usually one of addSuccess, addError or addFailure would have been called.
|
||||
# But there are some path in unittest that would bypass this.
|
||||
# We must disconnect stdout in stopTest(), which is guaranteed to be called.
|
||||
self.complete_output()
|
||||
|
||||
|
||||
def addSuccess(self, test):
|
||||
self.success_count += 1
|
||||
TestResult.addSuccess(self, test)
|
||||
output = self.complete_output()
|
||||
self.result.append((0, test, output, ''))
|
||||
if self.verbosity > 1:
|
||||
sys.stderr.write('ok ')
|
||||
sys.stderr.write(str(test))
|
||||
sys.stderr.write('\n')
|
||||
else:
|
||||
sys.stderr.write('.')
|
||||
|
||||
def addError(self, test, err):
|
||||
self.error_count += 1
|
||||
TestResult.addError(self, test, err)
|
||||
_, _exc_str = self.errors[-1]
|
||||
output = self.complete_output()
|
||||
self.result.append((2, test, output, _exc_str))
|
||||
if self.verbosity > 1:
|
||||
sys.stderr.write('E ')
|
||||
sys.stderr.write(str(test))
|
||||
sys.stderr.write('\n')
|
||||
else:
|
||||
sys.stderr.write('E')
|
||||
|
||||
def addFailure(self, test, err):
|
||||
self.failure_count += 1
|
||||
TestResult.addFailure(self, test, err)
|
||||
_, _exc_str = self.failures[-1]
|
||||
output = self.complete_output()
|
||||
self.result.append((1, test, output, _exc_str))
|
||||
if self.verbosity > 1:
|
||||
sys.stderr.write('F ')
|
||||
sys.stderr.write(str(test))
|
||||
sys.stderr.write('\n')
|
||||
else:
|
||||
sys.stderr.write('F')
|
||||
|
||||
|
||||
class HTMLTestRunner(Template_mixin):
|
||||
"""
|
||||
"""
|
||||
def __init__(self, stream=sys.stdout, verbosity=1, title=None, description=None):
|
||||
self.stream = stream
|
||||
self.verbosity = verbosity
|
||||
if title is None:
|
||||
self.title = self.DEFAULT_TITLE
|
||||
else:
|
||||
self.title = title
|
||||
if description is None:
|
||||
self.description = self.DEFAULT_DESCRIPTION
|
||||
else:
|
||||
self.description = description
|
||||
|
||||
self.startTime = datetime.datetime.now()
|
||||
|
||||
|
||||
def run(self, test):
|
||||
"Run the given test case or test suite."
|
||||
result = _TestResult(self.verbosity)
|
||||
test(result)
|
||||
self.stopTime = datetime.datetime.now()
|
||||
self.generateReport(test, result)
|
||||
print >>sys.stderr, '\nTime Elapsed: %s' % (self.stopTime-self.startTime)
|
||||
return result
|
||||
|
||||
|
||||
def sortResult(self, result_list):
|
||||
# unittest does not seems to run in any particular order.
|
||||
# Here at least we want to group them together by class.
|
||||
rmap = {}
|
||||
classes = []
|
||||
for n,t,o,e in result_list:
|
||||
cls = t.__class__
|
||||
if not rmap.has_key(cls):
|
||||
rmap[cls] = []
|
||||
classes.append(cls)
|
||||
rmap[cls].append((n,t,o,e))
|
||||
r = [(cls, rmap[cls]) for cls in classes]
|
||||
return r
|
||||
|
||||
|
||||
def getReportAttributes(self, result):
|
||||
"""
|
||||
Return report attributes as a list of (name, value).
|
||||
Override this to add custom attributes.
|
||||
"""
|
||||
startTime = str(self.startTime)[:19]
|
||||
duration = str(self.stopTime - self.startTime)
|
||||
status = []
|
||||
if result.success_count: status.append('Pass %s' % result.success_count)
|
||||
if result.failure_count: status.append('Failure %s' % result.failure_count)
|
||||
if result.error_count: status.append('Error %s' % result.error_count )
|
||||
if status:
|
||||
status = ' '.join(status)
|
||||
else:
|
||||
status = 'none'
|
||||
return [
|
||||
('Start Time', startTime),
|
||||
('Duration', duration),
|
||||
('Status', status),
|
||||
]
|
||||
|
||||
|
||||
def generateReport(self, test, result):
|
||||
report_attrs = self.getReportAttributes(result)
|
||||
generator = 'HTMLTestRunner %s' % __version__
|
||||
stylesheet = self._generate_stylesheet()
|
||||
heading = self._generate_heading(report_attrs)
|
||||
report = self._generate_report(result)
|
||||
ending = self._generate_ending()
|
||||
output = self.HTML_TMPL % dict(
|
||||
title = saxutils.escape(self.title),
|
||||
generator = generator,
|
||||
stylesheet = stylesheet,
|
||||
heading = heading,
|
||||
report = report,
|
||||
ending = ending,
|
||||
)
|
||||
self.stream.write(output.encode('utf8'))
|
||||
|
||||
|
||||
def _generate_stylesheet(self):
|
||||
return self.STYLESHEET_TMPL
|
||||
|
||||
|
||||
def _generate_heading(self, report_attrs):
|
||||
a_lines = []
|
||||
for name, value in report_attrs:
|
||||
line = self.HEADING_ATTRIBUTE_TMPL % dict(
|
||||
name = saxutils.escape(name),
|
||||
value = saxutils.escape(value),
|
||||
)
|
||||
a_lines.append(line)
|
||||
heading = self.HEADING_TMPL % dict(
|
||||
title = saxutils.escape(self.title),
|
||||
parameters = ''.join(a_lines),
|
||||
description = saxutils.escape(self.description),
|
||||
)
|
||||
return heading
|
||||
|
||||
|
||||
def _generate_report(self, result):
|
||||
rows = []
|
||||
sortedResult = self.sortResult(result.result)
|
||||
for cid, (cls, cls_results) in enumerate(sortedResult):
|
||||
# subtotal for a class
|
||||
np = nf = ne = 0
|
||||
for n,t,o,e in cls_results:
|
||||
if n == 0: np += 1
|
||||
elif n == 1: nf += 1
|
||||
else: ne += 1
|
||||
|
||||
# format class description
|
||||
if cls.__module__ == "__main__":
|
||||
name = cls.__name__
|
||||
else:
|
||||
name = "%s.%s" % (cls.__module__, cls.__name__)
|
||||
doc = cls.__doc__ and cls.__doc__.split("\n")[0] or ""
|
||||
desc = doc and '%s: %s' % (name, doc) or name
|
||||
|
||||
row = self.REPORT_CLASS_TMPL % dict(
|
||||
style = ne > 0 and 'errorClass' or nf > 0 and 'failClass' or 'passClass',
|
||||
desc = desc,
|
||||
count = np+nf+ne,
|
||||
Pass = np,
|
||||
fail = nf,
|
||||
error = ne,
|
||||
cid = 'c%s' % (cid+1),
|
||||
)
|
||||
rows.append(row)
|
||||
|
||||
for tid, (n,t,o,e) in enumerate(cls_results):
|
||||
self._generate_report_test(rows, cid, tid, n, t, o, e)
|
||||
|
||||
report = self.REPORT_TMPL % dict(
|
||||
test_list = ''.join(rows),
|
||||
count = str(result.success_count+result.failure_count+result.error_count),
|
||||
Pass = str(result.success_count),
|
||||
fail = str(result.failure_count),
|
||||
error = str(result.error_count),
|
||||
)
|
||||
return report
|
||||
|
||||
|
||||
def _generate_report_test(self, rows, cid, tid, n, t, o, e):
|
||||
# e.g. 'pt1.1', 'ft1.1', etc
|
||||
has_output = bool(o or e)
|
||||
tid = (n == 0 and 'p' or 'f') + 't%s.%s' % (cid+1,tid+1)
|
||||
name = t.id().split('.')[-1]
|
||||
doc = t.shortDescription() or ""
|
||||
desc = doc and ('%s: %s' % (name, doc)) or name
|
||||
tmpl = has_output and self.REPORT_TEST_WITH_OUTPUT_TMPL or self.REPORT_TEST_NO_OUTPUT_TMPL
|
||||
|
||||
# o and e should be byte string because they are collected from stdout and stderr?
|
||||
if isinstance(o,str):
|
||||
# TODO: some problem with 'string_escape': it escape \n and mess up formating
|
||||
# uo = unicode(o.encode('string_escape'))
|
||||
uo = o.decode('latin-1')
|
||||
else:
|
||||
uo = o
|
||||
if isinstance(e,str):
|
||||
# TODO: some problem with 'string_escape': it escape \n and mess up formating
|
||||
# ue = unicode(e.encode('string_escape'))
|
||||
ue = e.decode('latin-1')
|
||||
else:
|
||||
ue = e
|
||||
|
||||
script = self.REPORT_TEST_OUTPUT_TMPL % dict(
|
||||
id = tid,
|
||||
output = saxutils.escape(uo+ue),
|
||||
)
|
||||
|
||||
row = tmpl % dict(
|
||||
tid = tid,
|
||||
Class = (n == 0 and 'hiddenRow' or 'none'),
|
||||
style = n == 2 and 'errorCase' or (n == 1 and 'failCase' or 'none'),
|
||||
desc = desc,
|
||||
script = script,
|
||||
status = self.STATUS[n],
|
||||
)
|
||||
rows.append(row)
|
||||
if not has_output:
|
||||
return
|
||||
|
||||
def _generate_ending(self):
|
||||
return self.ENDING_TMPL
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Facilities for running tests from the command line
|
||||
##############################################################################
|
||||
|
||||
# Note: Reuse unittest.TestProgram to launch test. In the future we may
|
||||
# build our own launcher to support more specific command line
|
||||
# parameters like test title, CSS, etc.
|
||||
class TestProgram(unittest.TestProgram):
|
||||
"""
|
||||
A variation of the unittest.TestProgram. Please refer to the base
|
||||
class for command line parameters.
|
||||
"""
|
||||
def runTests(self):
|
||||
# Pick HTMLTestRunner as the default test runner.
|
||||
# base class's testRunner parameter is not useful because it means
|
||||
# we have to instantiate HTMLTestRunner before we know self.verbosity.
|
||||
if self.testRunner is None:
|
||||
self.testRunner = HTMLTestRunner(verbosity=self.verbosity)
|
||||
unittest.TestProgram.runTests(self)
|
||||
|
||||
main = TestProgram
|
||||
|
||||
##############################################################################
|
||||
# Executing this module from the command line
|
||||
##############################################################################
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(module=None)
|
||||
+55
-10
@@ -1,12 +1,15 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from pylab import norm
|
||||
from SimPEG.utils import mkvc
|
||||
from SimPEG.utils import mkvc, sdiag
|
||||
from SimPEG import utils
|
||||
from SimPEG.mesh import TensorMesh, LogicallyOrthogonalMesh
|
||||
import numpy as np
|
||||
import unittest
|
||||
import inspect
|
||||
|
||||
happiness = ['The test be workin!', 'You get a gold star!', 'Yay passed!', 'Happy little convergence test!', 'That was easy!', 'Testing is important.', 'You are awesome.', 'Go Test Go!', 'Once upon a time, a happy little test passed.', 'And then everyone was happy.']
|
||||
sadness = ['No gold star for you.','Try again soon.','Thankfully, persistence is a great substitute for talent.','It might be easier to call this a feature...','Coffee break?', 'Boooooooo :(', 'Testing is important. Do it again.']
|
||||
|
||||
class OrderTest(unittest.TestCase):
|
||||
"""
|
||||
@@ -67,8 +70,7 @@ class OrderTest(unittest.TestCase):
|
||||
|
||||
name = "Order Test"
|
||||
expectedOrders = 2. # This can be a list of orders, must be the same length as meshTypes
|
||||
_expectedOrder = 2.
|
||||
tolerance = 0.85
|
||||
tolerance = 0.85 # This can also be a list, must be the same length as meshTypes
|
||||
meshSizes = [4, 8, 16, 32]
|
||||
meshTypes = ['uniformTensorMesh']
|
||||
_meshType = meshTypes[0]
|
||||
@@ -124,6 +126,8 @@ class OrderTest(unittest.TestCase):
|
||||
|
||||
"""
|
||||
assert type(self.meshTypes) == list, 'meshTypes must be a list'
|
||||
if type(self.tolerance) is not list:
|
||||
self.tolerance = np.ones(len(self.meshTypes))*self.tolerance
|
||||
|
||||
# if we just provide one expected order, repeat it for each mesh type
|
||||
if type(self.expectedOrders) == float or type(self.expectedOrders) == int:
|
||||
@@ -134,6 +138,7 @@ class OrderTest(unittest.TestCase):
|
||||
|
||||
for ii_meshType, meshType in enumerate(self.meshTypes):
|
||||
self._meshType = meshType
|
||||
self._tolerance = self.tolerance[ii_meshType]
|
||||
self._expectedOrder = self.expectedOrders[ii_meshType]
|
||||
|
||||
order = []
|
||||
@@ -144,7 +149,7 @@ class OrderTest(unittest.TestCase):
|
||||
err = self.getError()
|
||||
if ii == 0:
|
||||
print ''
|
||||
print 'Testing convergence on ' + self.M._meshType + ' of: ' + self.name
|
||||
print self._meshType + ': ' + self.name
|
||||
print '_____________________________________________'
|
||||
print ' h | error | e(i-1)/e(i) | order'
|
||||
print '~~~~~~|~~~~~~~~~~~~~|~~~~~~~~~~~~~|~~~~~~~~~~'
|
||||
@@ -155,21 +160,28 @@ class OrderTest(unittest.TestCase):
|
||||
err_old = err
|
||||
max_h_old = max_h
|
||||
print '---------------------------------------------'
|
||||
passTest = np.mean(np.array(order)) > self.tolerance*self._expectedOrder
|
||||
passTest = np.mean(np.array(order)) > self._tolerance*self._expectedOrder
|
||||
if passTest:
|
||||
print ['The test be workin!', 'You get a gold star!', 'Yay passed!', 'Happy little convergence test!', 'That was easy!'][np.random.randint(5)]
|
||||
print happiness[np.random.randint(len(happiness))]
|
||||
else:
|
||||
print 'Failed to pass test on ' + self._meshType + '.'
|
||||
print sadness[np.random.randint(len(sadness))]
|
||||
print ''
|
||||
self.assertTrue(passTest)
|
||||
|
||||
def Rosenbrock(x):
|
||||
def Rosenbrock(x, return_g=True, return_H=True):
|
||||
"""Rosenbrock function for testing GaussNewton scheme"""
|
||||
|
||||
f = 100*(x[1]-x[0]**2)**2+(1-x[0])**2
|
||||
g = np.array([2*(200*x[0]**3-200*x[0]*x[1]+x[0]-1), 200*(x[1]-x[0]**2)])
|
||||
H = np.array([[-400*x[1]+1200*x[0]**2+2, -400*x[0]], [-400*x[0], 200]])
|
||||
return f, g, H
|
||||
|
||||
out = (f,)
|
||||
if return_g:
|
||||
out += (g,)
|
||||
if return_H:
|
||||
out += (H,)
|
||||
return out
|
||||
|
||||
def checkDerivative(fctn, x0, num=7, plotIt=True, dx=None):
|
||||
"""
|
||||
@@ -186,6 +198,16 @@ def checkDerivative(fctn, x0, num=7, plotIt=True, dx=None):
|
||||
:rtype: bool
|
||||
:return: did you pass the test?!
|
||||
|
||||
|
||||
.. plot::
|
||||
:include-source:
|
||||
|
||||
from SimPEG.tests import checkDerivative
|
||||
from SimPEG.utils import sdiag
|
||||
import numpy as np
|
||||
def simplePass(x):
|
||||
return np.sin(x), sdiag(np.cos(x))
|
||||
checkDerivative(simplePass, np.random.randn(5))
|
||||
"""
|
||||
|
||||
print "%s checkDerivative %s" % ('='*20, '='*20)
|
||||
@@ -206,7 +228,11 @@ def checkDerivative(fctn, x0, num=7, plotIt=True, dx=None):
|
||||
for i in range(num):
|
||||
Jt = fctn(x0+t[i]*dx)
|
||||
E0[i] = l2norm(Jt[0]-Jc[0]) # 0th order Taylor
|
||||
E1[i] = l2norm(Jt[0]-Jc[0]-t[i]*Jc[1].dot(dx)) # 1st order Taylor
|
||||
if inspect.isfunction(Jc[1]):
|
||||
E1[i] = l2norm(Jt[0]-Jc[0]-t[i]*Jc[1](dx)) # 1st order Taylor
|
||||
else:
|
||||
# We assume it is a numpy.ndarray
|
||||
E1[i] = l2norm(Jt[0]-Jc[0]-t[i]*Jc[1].dot(dx)) # 1st order Taylor
|
||||
order0 = np.log10(E0[:-1]/E0[1:])
|
||||
order1 = np.log10(E1[:-1]/E1[1:])
|
||||
print "%d\t%1.2e\t%1.3e\t\t%1.3e\t\t%1.3f" % (i, t[i], E0[i], E1[i], np.nan if i == 0 else order1[i-1])
|
||||
@@ -222,9 +248,12 @@ def checkDerivative(fctn, x0, num=7, plotIt=True, dx=None):
|
||||
passTest = belowTol or correctOrder
|
||||
|
||||
if passTest:
|
||||
print "%s PASS! %s\n" % ('='*25, '='*25)
|
||||
print "%s PASS! %s" % ('='*25, '='*25)
|
||||
print happiness[np.random.randint(len(happiness))]+'\n'
|
||||
else:
|
||||
print "%s\n%s FAIL! %s\n%s" % ('*'*57, '<'*25, '>'*25, '*'*57)
|
||||
print sadness[np.random.randint(len(sadness))]+'\n'
|
||||
|
||||
|
||||
if plotIt:
|
||||
plt.figure()
|
||||
@@ -238,3 +267,19 @@ def checkDerivative(fctn, x0, num=7, plotIt=True, dx=None):
|
||||
plt.show()
|
||||
|
||||
return passTest
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
def simplePass(x):
|
||||
return np.sin(x), sdiag(np.cos(x))
|
||||
|
||||
def simpleFunction(x):
|
||||
return np.sin(x), lambda xi: sdiag(np.cos(x))*xi
|
||||
|
||||
def simpleFail(x):
|
||||
return np.sin(x), -sdiag(np.cos(x))
|
||||
|
||||
checkDerivative(simplePass, np.random.randn(5), plotIt=False)
|
||||
checkDerivative(simpleFunction, np.random.randn(5), plotIt=False)
|
||||
checkDerivative(simpleFail, np.random.randn(5), plotIt=False)
|
||||
|
||||
@@ -0,0 +1,355 @@
|
||||
.. _api_TestResults:
|
||||
|
||||
.. raw:: html
|
||||
<style type="text/css" media="screen">
|
||||
body { font-family: verdana, arial, helvetica, sans-serif; font-size: 80%; }
|
||||
table { font-size: 100%; }
|
||||
pre { }
|
||||
|
||||
/* -- heading ---------------------------------------------------------------------- */
|
||||
h1 {
|
||||
font-size: 16pt;
|
||||
color: gray;
|
||||
}
|
||||
.heading {
|
||||
margin-top: 0ex;
|
||||
margin-bottom: 1ex;
|
||||
}
|
||||
|
||||
.heading .attribute {
|
||||
margin-top: 1ex;
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.heading .description {
|
||||
margin-top: 4ex;
|
||||
margin-bottom: 6ex;
|
||||
}
|
||||
|
||||
/* -- css div popup ------------------------------------------------------------------------ */
|
||||
a.popup_link {
|
||||
}
|
||||
|
||||
a.popup_link:hover {
|
||||
color: red;
|
||||
}
|
||||
|
||||
.popup_window {
|
||||
display: none;
|
||||
position: relative;
|
||||
left: 0px;
|
||||
top: 0px;
|
||||
/*border: solid #627173 1px; */
|
||||
padding: 10px;
|
||||
background-color: #E6E6D6;
|
||||
font-family: "Lucida Console", "Courier New", Courier, monospace;
|
||||
text-align: left;
|
||||
font-size: 8pt;
|
||||
width: 500px;
|
||||
}
|
||||
|
||||
}
|
||||
/* -- report ------------------------------------------------------------------------ */
|
||||
#show_detail_line {
|
||||
margin-top: 3ex;
|
||||
margin-bottom: 1ex;
|
||||
}
|
||||
#result_table {
|
||||
width: 80%;
|
||||
border-collapse: collapse;
|
||||
border: 1px solid #777;
|
||||
}
|
||||
#header_row {
|
||||
font-weight: bold;
|
||||
color: white;
|
||||
background-color: #777;
|
||||
}
|
||||
#result_table td {
|
||||
border: 1px solid #777;
|
||||
padding: 2px;
|
||||
}
|
||||
#total_row { font-weight: bold; }
|
||||
.passClass { background-color: #6c6; }
|
||||
.failClass { background-color: #c60; }
|
||||
.errorClass { background-color: #c00; }
|
||||
.passCase { color: #6c6; }
|
||||
.failCase { color: #c60; font-weight: bold; }
|
||||
.errorCase { color: #c00; font-weight: bold; }
|
||||
.hiddenRow { display: none; }
|
||||
.testcase { margin-left: 2em; }
|
||||
|
||||
|
||||
/* -- ending ---------------------------------------------------------------------- */
|
||||
#ending {
|
||||
}
|
||||
|
||||
</style>
|
||||
|
||||
<body>
|
||||
<script language="javascript" type="text/javascript"><!--
|
||||
output_list = Array();
|
||||
|
||||
/* level - 0:Summary; 1:Failed; 2:All */
|
||||
function showCase(level) {
|
||||
trs = document.getElementsByTagName("tr");
|
||||
for (var i = 0; i < trs.length; i++) {
|
||||
tr = trs[i];
|
||||
id = tr.id;
|
||||
if (id.substr(0,2) == 'ft') {
|
||||
if (level < 1) {
|
||||
tr.className = 'hiddenRow';
|
||||
}
|
||||
else {
|
||||
tr.className = '';
|
||||
}
|
||||
}
|
||||
if (id.substr(0,2) == 'pt') {
|
||||
if (level > 1) {
|
||||
tr.className = '';
|
||||
}
|
||||
else {
|
||||
tr.className = 'hiddenRow';
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function showClassDetail(cid, count) {
|
||||
var id_list = Array(count);
|
||||
var toHide = 1;
|
||||
for (var i = 0; i < count; i++) {
|
||||
tid0 = 't' + cid.substr(1) + '.' + (i+1);
|
||||
tid = 'f' + tid0;
|
||||
tr = document.getElementById(tid);
|
||||
if (!tr) {
|
||||
tid = 'p' + tid0;
|
||||
tr = document.getElementById(tid);
|
||||
}
|
||||
id_list[i] = tid;
|
||||
if (tr.className) {
|
||||
toHide = 0;
|
||||
}
|
||||
}
|
||||
for (var i = 0; i < count; i++) {
|
||||
tid = id_list[i];
|
||||
if (toHide) {
|
||||
document.getElementById('div_'+tid).style.display = 'none'
|
||||
document.getElementById(tid).className = 'hiddenRow';
|
||||
}
|
||||
else {
|
||||
document.getElementById(tid).className = '';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function showTestDetail(div_id){
|
||||
var details_div = document.getElementById(div_id)
|
||||
var displayState = details_div.style.display
|
||||
// alert(displayState)
|
||||
if (displayState != 'block' ) {
|
||||
displayState = 'block'
|
||||
details_div.style.display = 'block'
|
||||
}
|
||||
else {
|
||||
details_div.style.display = 'none'
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function html_escape(s) {
|
||||
s = s.replace(/&/g,'&');
|
||||
s = s.replace(/</g,'<');
|
||||
s = s.replace(/>/g,'>');
|
||||
return s;
|
||||
}
|
||||
|
||||
/* obsoleted by detail in <div>
|
||||
function showOutput(id, name) {
|
||||
var w = window.open("", //url
|
||||
name,
|
||||
"resizable,scrollbars,status,width=800,height=450");
|
||||
d = w.document;
|
||||
d.write("<pre>");
|
||||
d.write(html_escape(output_list[id]));
|
||||
d.write("\n");
|
||||
d.write("<a href='javascript:window.close()'>close</a>\n");
|
||||
d.write("</pre>\n");
|
||||
d.close();
|
||||
}
|
||||
*/
|
||||
--></script>
|
||||
|
||||
<div class='heading'>
|
||||
<h1>Test Report</h1>
|
||||
<p class='attribute'><strong>Start Time:</strong> 2013-11-05 15:24:44</p>
|
||||
<p class='attribute'><strong>Duration:</strong> 0:00:00.007500</p>
|
||||
<p class='attribute'><strong>Status:</strong> Pass 22</p>
|
||||
|
||||
<p class='description'>This demonstrates the report output by Prasanna.Yelsangikar.</p>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
<p id='show_detail_line'>Show
|
||||
<a href='javascript:showCase(0)'>Summary</a>
|
||||
<a href='javascript:showCase(1)'>Failed</a>
|
||||
<a href='javascript:showCase(2)'>All</a>
|
||||
</p>
|
||||
<table id='result_table'>
|
||||
<colgroup>
|
||||
<col align='left' />
|
||||
<col align='right' />
|
||||
<col align='right' />
|
||||
<col align='right' />
|
||||
<col align='right' />
|
||||
<col align='right' />
|
||||
</colgroup>
|
||||
<tr id='header_row'>
|
||||
<td>Test Group/Test case</td>
|
||||
<td>Count</td>
|
||||
<td>Pass</td>
|
||||
<td>Fail</td>
|
||||
<td>Error</td>
|
||||
<td>View</td>
|
||||
</tr>
|
||||
|
||||
<tr class='passClass'>
|
||||
<td>test_basemesh.TestBaseMesh</td>
|
||||
<td>11</td>
|
||||
<td>11</td>
|
||||
<td>0</td>
|
||||
<td>0</td>
|
||||
<td><a href="javascript:showClassDetail('c1',11)">Detail</a></td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt1.1' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_meshDimensions</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt1.2' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_nc</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt1.3' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_nc_xyz</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt1.4' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_ne</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt1.5' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_nf</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt1.6' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_numbers</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt1.7' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_r_CC_M</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt1.8' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_r_E_M</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt1.9' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_r_E_V</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt1.10' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_r_F_M</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt1.11' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_r_F_V</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr class='passClass'>
|
||||
<td>test_basemesh.TestMeshNumbers2D</td>
|
||||
<td>11</td>
|
||||
<td>11</td>
|
||||
<td>0</td>
|
||||
<td>0</td>
|
||||
<td><a href="javascript:showClassDetail('c2',11)">Detail</a></td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt2.1' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_meshDimensions</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt2.2' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_nc</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt2.3' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_nc_xyz</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt2.4' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_ne</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt2.5' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_nf</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt2.6' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_numbers</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt2.7' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_r_CC_M</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt2.8' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_r_E_M</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt2.9' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_r_E_V</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt2.10' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_r_F_M</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='pt2.11' class='hiddenRow'>
|
||||
<td class='none'><div class='testcase'>test_mesh_r_F_V</div></td>
|
||||
<td colspan='5' align='center'>pass</td>
|
||||
</tr>
|
||||
|
||||
<tr id='total_row'>
|
||||
<td>Total</td>
|
||||
<td>22</td>
|
||||
<td>22</td>
|
||||
<td>0</td>
|
||||
<td>0</td>
|
||||
<td> </td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
@@ -1,11 +1,50 @@
|
||||
import os
|
||||
import glob
|
||||
import unittest
|
||||
import HTMLTestRunner
|
||||
|
||||
# This code will run all tests in directory named test_*.py
|
||||
|
||||
TITLE = 'Test Results'
|
||||
test_file_strings = glob.glob('test_*.py')
|
||||
module_strings = [str[0:len(str)-3] for str in test_file_strings]
|
||||
suites = [unittest.defaultTestLoader.loadTestsFromName(str) for str
|
||||
in module_strings]
|
||||
testSuite = unittest.TestSuite(suites)
|
||||
text_runner = unittest.TextTestRunner().run(testSuite)
|
||||
unittest.TextTestRunner(verbosity=2).run(testSuite)
|
||||
|
||||
|
||||
outfile = open("report.html", "w")
|
||||
runner = HTMLTestRunner.HTMLTestRunner(
|
||||
stream=outfile,
|
||||
title=TITLE,
|
||||
description='SimPEG Test Report was automatically generated.'
|
||||
)
|
||||
|
||||
runner.run(testSuite)
|
||||
outfile.close()
|
||||
|
||||
reader = open("report.html", "r")
|
||||
writer = open("../../docs/api_TestResults.rst", "w")
|
||||
|
||||
writer.write('.. _api_TestResults:\n\nTest Results\n============\n\n.. raw:: html\n\n')
|
||||
|
||||
go = False
|
||||
for line in reader:
|
||||
skip = False
|
||||
if line == '<style type="text/css" media="screen">\n':
|
||||
go = True
|
||||
elif line == "<div id='ending'> </div>\n":
|
||||
go = False
|
||||
elif line == '</head>\n':
|
||||
skip = True
|
||||
elif line == '<h1>'+TITLE+'</h1>\n':
|
||||
skip = True
|
||||
elif line == '<body>\n':
|
||||
skip = True
|
||||
if go and not skip:
|
||||
writer.write(' '+line)
|
||||
|
||||
writer.close()
|
||||
reader.close()
|
||||
os.remove("report.html")
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
import unittest
|
||||
from SimPEG import Solver
|
||||
from SimPEG.mesh import TensorMesh
|
||||
from SimPEG.utils import sdiag
|
||||
import numpy as np
|
||||
import scipy.sparse as sparse
|
||||
|
||||
TOL = 1e-10
|
||||
numRHS = 5
|
||||
|
||||
|
||||
class TestSolver(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
h1 = np.ones(10)*100.
|
||||
h2 = np.ones(10)*100.
|
||||
h3 = np.ones(10)*100.
|
||||
|
||||
h = [h1,h2,h3]
|
||||
|
||||
M = TensorMesh(h)
|
||||
|
||||
D = M.faceDiv
|
||||
G = M.cellGrad
|
||||
Msig = M.getFaceMass()
|
||||
A = D*Msig*G
|
||||
A[0,0] *= 10 # remove the constant null space from the matrix
|
||||
|
||||
self.A = A
|
||||
self.M = M
|
||||
|
||||
def test_directFactored_1(self):
|
||||
solve = Solver(self.A, doDirect=True, flag=None, options={'factorize':True,'backend':'scipy'})
|
||||
e = np.ones(self.M.nC)
|
||||
rhs = self.A.dot(e)
|
||||
x = solve.solve(rhs)
|
||||
self.assertTrue(np.linalg.norm(e-x,np.inf) < TOL, True)
|
||||
|
||||
|
||||
def test_directFactored_M(self):
|
||||
solve = Solver(self.A, doDirect=True, flag=None, options={'factorize':True,'backend':'scipy'})
|
||||
e = np.ones((self.M.nC,numRHS))
|
||||
rhs = self.A.dot(e)
|
||||
x = solve.solve(rhs)
|
||||
self.assertTrue(np.linalg.norm(e-x,np.inf) < TOL, True)
|
||||
|
||||
def test_directSpsolve_1(self):
|
||||
solve = Solver(self.A, doDirect=True, flag=None, options={'factorize':False,'backend':'scipy'})
|
||||
e = np.ones(self.M.nC)
|
||||
rhs = self.A.dot(e)
|
||||
x = solve.solve(rhs)
|
||||
self.assertTrue(np.linalg.norm(e-x,np.inf) < TOL, True)
|
||||
|
||||
def test_directSpsolve_M(self):
|
||||
solve = Solver(self.A, doDirect=True, flag=None, options={'factorize':False,'backend':'scipy'})
|
||||
e = np.ones((self.M.nC, numRHS))
|
||||
rhs = self.A.dot(e)
|
||||
x = solve.solve(rhs)
|
||||
self.assertTrue(np.linalg.norm(e-x,np.inf) < TOL, True)
|
||||
|
||||
def test_directLower_1(self):
|
||||
AL = sparse.tril(self.A)
|
||||
solve = Solver(AL, doDirect=True, flag='L', options={})
|
||||
e = np.ones(self.M.nC)
|
||||
rhs = AL.dot(e)
|
||||
x = solve.solve(rhs)
|
||||
self.assertTrue(np.linalg.norm(e-x,np.inf) < TOL, True)
|
||||
|
||||
def test_directLower_M(self):
|
||||
AL = sparse.tril(self.A)
|
||||
solve = Solver(AL, doDirect=True, flag='L', options={})
|
||||
e = np.ones((self.M.nC,numRHS))
|
||||
rhs = AL.dot(e)
|
||||
x = solve.solve(rhs)
|
||||
self.assertTrue(np.linalg.norm(e-x,np.inf) < TOL, True)
|
||||
|
||||
def test_directUpper_1(self):
|
||||
AU = sparse.triu(self.A)
|
||||
solve = Solver(AU, doDirect=True, flag='U', options={})
|
||||
e = np.ones(self.M.nC)
|
||||
rhs = AU.dot(e)
|
||||
x = solve.solve(rhs)
|
||||
self.assertTrue(np.linalg.norm(e-x,np.inf) < TOL, True)
|
||||
|
||||
def test_directUpper_M(self):
|
||||
AU = sparse.triu(self.A)
|
||||
solve = Solver(AU, doDirect=True, flag='U', options={})
|
||||
e = np.ones((self.M.nC,numRHS))
|
||||
rhs = AU.dot(e)
|
||||
x = solve.solve(rhs)
|
||||
self.assertTrue(np.linalg.norm(e-x,np.inf) < TOL, True)
|
||||
|
||||
def test_directDiagonal_1(self):
|
||||
AD = sdiag(self.A.diagonal())
|
||||
solve = Solver(AD, doDirect=True, flag='D', options={})
|
||||
e = np.ones(self.M.nC)
|
||||
rhs = AD.dot(e)
|
||||
x = solve.solve(rhs)
|
||||
self.assertTrue(np.linalg.norm(e-x,np.inf) < TOL, True)
|
||||
|
||||
def test_directDiagonal_M(self):
|
||||
AD = sdiag(self.A.diagonal())
|
||||
solve = Solver(AD, doDirect=True, flag='D', options={})
|
||||
e = np.ones((self.M.nC,numRHS))
|
||||
rhs = AD.dot(e)
|
||||
x = solve.solve(rhs)
|
||||
self.assertTrue(np.linalg.norm(e-x,np.inf) < TOL, True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -3,9 +3,11 @@ import unittest
|
||||
from SimPEG.mesh import TensorMesh
|
||||
from SimPEG.utils import ModelBuilder, sdiag
|
||||
from SimPEG.forward import Problem, SyntheticProblem
|
||||
from SimPEG.forward.DCProblem import DCProblem, DCutils
|
||||
from SimPEG.forward.DCProblem import *
|
||||
from TestUtils import checkDerivative
|
||||
from scipy.sparse.linalg import dsolve
|
||||
from SimPEG.regularization import Regularization
|
||||
from SimPEG import inverse
|
||||
|
||||
|
||||
class DCProblemTests(unittest.TestCase):
|
||||
@@ -34,7 +36,7 @@ class DCProblemTests(unittest.TestCase):
|
||||
elecend = 0.5+spacelec*(nelec-1)
|
||||
elecLocR = np.linspace(elecini, elecend, nelec)
|
||||
rxmidLoc = (elecLocR[0:nelec-1]+elecLocR[1:nelec])*0.5
|
||||
q, Q, rxmidloc = DCutils.genTxRxmat(nelec, spacelec, surfloc, elecini, mesh)
|
||||
q, Q, rxmidloc = genTxRxmat(nelec, spacelec, surfloc, elecini, mesh)
|
||||
P = Q.T
|
||||
|
||||
# Create some data
|
||||
@@ -52,22 +54,27 @@ class DCProblemTests(unittest.TestCase):
|
||||
problem.RHS = q
|
||||
problem.W = Wd
|
||||
problem.dobs = dobs
|
||||
problem.std = dobs*0 + 0.05
|
||||
|
||||
opt = inverse.InexactGaussNewton(maxIterLS=20, maxIter=10, tolF=1e-6, tolX=1e-6, tolG=1e-6, maxIterCG=6)
|
||||
reg = Regularization(mesh)
|
||||
inv = inverse.Inversion(problem, reg, opt, beta0=1e4)
|
||||
|
||||
self.inv = inv
|
||||
self.reg = reg
|
||||
self.p = problem
|
||||
self.mesh = mesh
|
||||
self.m0 = mSynth
|
||||
self.dobs = dobs
|
||||
|
||||
|
||||
def test_misfit(self):
|
||||
print 'SimPEG.forward.DCProblem: Testing Misfit'
|
||||
derChk = lambda m: [self.p.misfit(m), self.p.misfitDeriv(m)]
|
||||
derChk = lambda m: [self.p.dpred(m), lambda mx: self.p.J(self.m0, mx)]
|
||||
passed = checkDerivative(derChk, self.m0, plotIt=False)
|
||||
self.assertTrue(passed)
|
||||
|
||||
def test_adjoint(self):
|
||||
# Adjoint Test
|
||||
u = np.random.rand(self.mesh.nC)
|
||||
u = np.random.rand(self.mesh.nC*self.p.RHS.shape[1])
|
||||
v = np.random.rand(self.mesh.nC)
|
||||
w = np.random.rand(self.dobs.shape[0])
|
||||
wtJv = w.dot(self.p.J(self.m0, v, u=u))
|
||||
@@ -75,6 +82,13 @@ class DCProblemTests(unittest.TestCase):
|
||||
passed = (wtJv - vtJtw) < 1e-10
|
||||
self.assertTrue(passed)
|
||||
|
||||
def test_dataObj(self):
|
||||
derChk = lambda m: [self.inv.dataObj(m), self.inv.dataObjDeriv(m)]
|
||||
checkDerivative(derChk, self.m0, plotIt=False)
|
||||
|
||||
def test_modelObj(self):
|
||||
derChk = lambda m: [self.reg.modelObj(m), self.reg.modelObjDeriv(m)]
|
||||
checkDerivative(derChk, self.m0, plotIt=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -2,6 +2,7 @@ import numpy as np
|
||||
import unittest
|
||||
from SimPEG.mesh import TensorMesh
|
||||
from SimPEG.forward import Problem
|
||||
from SimPEG.regularization import Regularization
|
||||
from TestUtils import checkDerivative
|
||||
from scipy.sparse.linalg import dsolve
|
||||
|
||||
@@ -15,7 +16,7 @@ class ProblemTests(unittest.TestCase):
|
||||
c = np.array([1, 4])
|
||||
self.mesh2 = TensorMesh([a, b], np.array([3, 5]))
|
||||
self.p2 = Problem(self.mesh2)
|
||||
|
||||
self.reg = Regularization(self.mesh2)
|
||||
|
||||
def test_modelTransform(self):
|
||||
print 'SimPEG.forward.Problem: Testing Model Transform'
|
||||
@@ -23,6 +24,13 @@ class ProblemTests(unittest.TestCase):
|
||||
passed = checkDerivative(lambda m : [self.p2.modelTransform(m), self.p2.modelTransformDeriv(m)], m, plotIt=False)
|
||||
self.assertTrue(passed)
|
||||
|
||||
def test_regularization(self):
|
||||
derChk = lambda m: [self.reg.modelObj(m), self.reg.modelObjDeriv(m)]
|
||||
mSynth = np.random.randn(self.mesh2.nC)
|
||||
checkDerivative(derChk, mSynth, plotIt=False)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -0,0 +1,200 @@
|
||||
import numpy as np
|
||||
import unittest
|
||||
from TestUtils import OrderTest
|
||||
from SimPEG.utils import mkvc
|
||||
|
||||
MESHTYPES = ['uniformTensorMesh', 'randomTensorMesh']
|
||||
TOLERANCES = [0.9, 0.55]
|
||||
call1 = lambda fun, xyz: fun(xyz)
|
||||
call2 = lambda fun, xyz: fun(xyz[:, 0], xyz[:, 1])
|
||||
call3 = lambda fun, xyz: fun(xyz[:, 0], xyz[:, 1], xyz[:, 2])
|
||||
cart_row2 = lambda g, xfun, yfun: np.c_[call2(xfun, g), call2(yfun, g)]
|
||||
cart_row3 = lambda g, xfun, yfun, zfun: np.c_[call3(xfun, g), call3(yfun, g), call3(zfun, g)]
|
||||
cartF2 = lambda M, fx, fy: np.vstack((cart_row2(M.gridFx, fx, fy), cart_row2(M.gridFy, fx, fy)))
|
||||
cartE2 = lambda M, ex, ey: np.vstack((cart_row2(M.gridEx, ex, ey), cart_row2(M.gridEy, ex, ey)))
|
||||
cartF3 = lambda M, fx, fy, fz: np.vstack((cart_row3(M.gridFx, fx, fy, fz), cart_row3(M.gridFy, fx, fy, fz), cart_row3(M.gridFz, fx, fy, fz)))
|
||||
cartE3 = lambda M, ex, ey, ez: np.vstack((cart_row3(M.gridEx, ex, ey, ez), cart_row3(M.gridEy, ex, ey, ez), cart_row3(M.gridEz, ex, ey, ez)))
|
||||
|
||||
|
||||
|
||||
class TestInterpolation1D(OrderTest):
|
||||
LOCS = np.random.rand(50)*0.6+0.2
|
||||
name = "Interpolation 1D"
|
||||
meshTypes = MESHTYPES
|
||||
tolerance = TOLERANCES
|
||||
meshDimension = 1
|
||||
meshSizes = [8, 16, 32]
|
||||
|
||||
def getError(self):
|
||||
funX = lambda x: np.cos(2*np.pi*x)
|
||||
|
||||
anal = call1(funX, self.LOCS)
|
||||
|
||||
if 'CC' == self.type:
|
||||
grid = call1(funX, self.M.gridCC)
|
||||
elif 'N' == self.type:
|
||||
grid = call1(funX, self.M.gridN)
|
||||
|
||||
comp = self.M.getInterpolationMat(self.LOCS, self.type)*grid
|
||||
|
||||
err = np.linalg.norm((comp - anal), 2)
|
||||
return err
|
||||
|
||||
def test_orderCC(self):
|
||||
self.type = 'CC'
|
||||
self.name = 'Interpolation 1D: CC'
|
||||
self.orderTest()
|
||||
|
||||
def test_orderN(self):
|
||||
self.type = 'N'
|
||||
self.name = 'Interpolation 1D: N'
|
||||
self.orderTest()
|
||||
|
||||
class TestInterpolation2d(OrderTest):
|
||||
name = "Interpolation 2D"
|
||||
LOCS = np.random.rand(50,2)*0.6+0.2
|
||||
meshTypes = MESHTYPES
|
||||
tolerance = TOLERANCES
|
||||
meshDimension = 2
|
||||
meshSizes = [8, 16, 32, 64]
|
||||
|
||||
def getError(self):
|
||||
funX = lambda x, y: np.cos(2*np.pi*y)
|
||||
funY = lambda x, y: np.cos(2*np.pi*x)
|
||||
|
||||
if 'x' in self.type:
|
||||
anal = call2(funX, self.LOCS)
|
||||
elif 'y' in self.type:
|
||||
anal = call2(funY, self.LOCS)
|
||||
else:
|
||||
anal = call2(funX, self.LOCS)
|
||||
|
||||
if 'F' in self.type:
|
||||
Fc = cartF2(self.M, funX, funY)
|
||||
grid = self.M.projectFaceVector(Fc)
|
||||
elif 'E' in self.type:
|
||||
Ec = cartE2(self.M, funX, funY)
|
||||
grid = self.M.projectEdgeVector(Ec)
|
||||
elif 'CC' == self.type:
|
||||
grid = call2(funX, self.M.gridCC)
|
||||
elif 'N' == self.type:
|
||||
grid = call2(funX, self.M.gridN)
|
||||
|
||||
comp = self.M.getInterpolationMat(self.LOCS, self.type)*grid
|
||||
|
||||
err = np.linalg.norm((comp - anal), np.inf)
|
||||
return err
|
||||
|
||||
def test_orderCC(self):
|
||||
self.type = 'CC'
|
||||
self.name = 'Interpolation 2D: CC'
|
||||
self.orderTest()
|
||||
|
||||
def test_orderN(self):
|
||||
self.type = 'N'
|
||||
self.name = 'Interpolation 2D: N'
|
||||
self.orderTest()
|
||||
|
||||
def test_orderFx(self):
|
||||
self.type = 'Fx'
|
||||
self.name = 'Interpolation 2D: Fx'
|
||||
self.orderTest()
|
||||
|
||||
def test_orderFy(self):
|
||||
self.type = 'Fy'
|
||||
self.name = 'Interpolation 2D: Fy'
|
||||
self.orderTest()
|
||||
|
||||
def test_orderEx(self):
|
||||
self.type = 'Ex'
|
||||
self.name = 'Interpolation 2D: Ex'
|
||||
self.orderTest()
|
||||
|
||||
def test_orderEy(self):
|
||||
self.type = 'Ey'
|
||||
self.name = 'Interpolation 2D: Ey'
|
||||
self.orderTest()
|
||||
|
||||
|
||||
|
||||
class TestInterpolation3D(OrderTest):
|
||||
name = "Interpolation"
|
||||
LOCS = np.random.rand(50,3)*0.6+0.2
|
||||
meshTypes = MESHTYPES
|
||||
tolerance = TOLERANCES
|
||||
meshDimension = 3
|
||||
meshSizes = [8, 16, 32, 64]
|
||||
|
||||
def getError(self):
|
||||
funX = lambda x, y, z: np.cos(2*np.pi*y)
|
||||
funY = lambda x, y, z: np.cos(2*np.pi*z)
|
||||
funZ = lambda x, y, z: np.cos(2*np.pi*x)
|
||||
|
||||
if 'x' in self.type:
|
||||
anal = call3(funX, self.LOCS)
|
||||
elif 'y' in self.type:
|
||||
anal = call3(funY, self.LOCS)
|
||||
elif 'z' in self.type:
|
||||
anal = call3(funZ, self.LOCS)
|
||||
else:
|
||||
anal = call3(funX, self.LOCS)
|
||||
|
||||
if 'F' in self.type:
|
||||
Fc = cartF3(self.M, funX, funY, funZ)
|
||||
grid = self.M.projectFaceVector(Fc)
|
||||
elif 'E' in self.type:
|
||||
Ec = cartE3(self.M, funX, funY, funZ)
|
||||
grid = self.M.projectEdgeVector(Ec)
|
||||
elif 'CC' == self.type:
|
||||
grid = call3(funX, self.M.gridCC)
|
||||
elif 'N' == self.type:
|
||||
grid = call3(funX, self.M.gridN)
|
||||
|
||||
comp = self.M.getInterpolationMat(self.LOCS, self.type)*grid
|
||||
|
||||
err = np.linalg.norm((comp - anal), np.inf)
|
||||
return err
|
||||
|
||||
def test_orderCC(self):
|
||||
self.type = 'CC'
|
||||
self.name = 'Interpolation CC'
|
||||
self.orderTest()
|
||||
|
||||
def test_orderN(self):
|
||||
self.type = 'N'
|
||||
self.name = 'Interpolation N'
|
||||
self.orderTest()
|
||||
|
||||
def test_orderFx(self):
|
||||
self.type = 'Fx'
|
||||
self.name = 'Interpolation Fx'
|
||||
self.orderTest()
|
||||
|
||||
def test_orderFy(self):
|
||||
self.type = 'Fy'
|
||||
self.name = 'Interpolation Fy'
|
||||
self.orderTest()
|
||||
|
||||
def test_orderFz(self):
|
||||
self.type = 'Fz'
|
||||
self.name = 'Interpolation Fz'
|
||||
self.orderTest()
|
||||
|
||||
def test_orderEx(self):
|
||||
self.type = 'Ex'
|
||||
self.name = 'Interpolation Ex'
|
||||
self.orderTest()
|
||||
|
||||
def test_orderEy(self):
|
||||
self.type = 'Ey'
|
||||
self.name = 'Interpolation Ey'
|
||||
self.orderTest()
|
||||
|
||||
def test_orderEz(self):
|
||||
self.type = 'Ez'
|
||||
self.name = 'Interpolation Ez'
|
||||
self.orderTest()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,6 +1,28 @@
|
||||
import numpy as np
|
||||
import unittest
|
||||
from SimPEG.utils import mkvc, ndgrid, indexCube, sdiag, inv3X3BlockDiagonal, inv2X2BlockDiagonal
|
||||
from SimPEG.tests import checkDerivative
|
||||
|
||||
|
||||
class TestCheckDerivative(unittest.TestCase):
|
||||
|
||||
def test_simplePass(self):
|
||||
def simplePass(x):
|
||||
return np.sin(x), sdiag(np.cos(x))
|
||||
passed = checkDerivative(simplePass, np.random.randn(5), plotIt=False)
|
||||
self.assertTrue(passed, True)
|
||||
|
||||
def test_simpleFunction(self):
|
||||
def simpleFunction(x):
|
||||
return np.sin(x), lambda xi: sdiag(np.cos(x))*xi
|
||||
passed = checkDerivative(simpleFunction, np.random.randn(5), plotIt=False)
|
||||
self.assertTrue(passed, True)
|
||||
|
||||
def test_simpleFail(self):
|
||||
def simpleFail(x):
|
||||
return np.sin(x), -sdiag(np.cos(x))
|
||||
passed = checkDerivative(simpleFail, np.random.randn(5), plotIt=False)
|
||||
self.assertTrue(not passed, True)
|
||||
|
||||
|
||||
class TestSequenceFunctions(unittest.TestCase):
|
||||
@@ -85,5 +107,6 @@ class TestSequenceFunctions(unittest.TestCase):
|
||||
self.assertTrue(np.linalg.norm(Z3.todense().ravel(), 2) < 1e-12)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -0,0 +1,207 @@
|
||||
import numpy as np
|
||||
import scipy.sparse as sparse
|
||||
import scipy.sparse.linalg as linalg
|
||||
|
||||
|
||||
class Solver(object):
|
||||
"""
|
||||
Solver is a light wrapper on the various types of
|
||||
linear solvers available in python.
|
||||
|
||||
:param scipy.sparse A: Matrix
|
||||
:param bool doDirect: if you want a direct solver
|
||||
:param string flag: Matrix type flag for special solves: [None, 'L', 'U', 'D']
|
||||
:param dict options: options which are passed to each sub solver, see each for details.
|
||||
:rtype: Solver
|
||||
:return: Solver
|
||||
|
||||
To use for direct solvers::
|
||||
|
||||
solve = Solver(A, doDirect=True, flag=None, options={'factorize':True,'backend':'scipy'})
|
||||
x = solve.solve(rhs)
|
||||
|
||||
Or in one line::
|
||||
|
||||
x = Solver(A).solve(rhs)
|
||||
|
||||
The flag can be set to None, 'L', 'U', or 'D', for general, lower, upper, and diagonal matrices, respectively.
|
||||
|
||||
"""
|
||||
def __init__(self, A, doDirect=True, flag=None, options={}):
|
||||
assert type(doDirect) is bool, 'doDirect must be a boolean'
|
||||
assert flag in [None, 'L', 'U', 'D'], "flag must be set to None, 'L', 'U', or 'D'"
|
||||
|
||||
self.A = A
|
||||
|
||||
self.dsolve = None
|
||||
self.doDirect = doDirect
|
||||
self.flag = flag
|
||||
self.options = options
|
||||
|
||||
def solve(self, b):
|
||||
"""
|
||||
Solves the linear system.
|
||||
|
||||
.. math::
|
||||
|
||||
Ax=b
|
||||
|
||||
:param numpy.ndarray b: the right hand side
|
||||
:rtype: numpy.ndarray
|
||||
:return: x
|
||||
"""
|
||||
if self.flag is None and self.doDirect:
|
||||
return self.solveDirect(b, **self.options)
|
||||
elif self.flag is None and not self.doDirect:
|
||||
return self.solveIter(b, **self.options)
|
||||
elif self.flag == 'U':
|
||||
return self.solveBackward(b)
|
||||
elif self.flag == 'L':
|
||||
return self.solveForward(b)
|
||||
elif self.flag == 'D':
|
||||
return self.solveDiagonal(b)
|
||||
else:
|
||||
raise Exception('Unknown flag.')
|
||||
pass
|
||||
|
||||
def clean(self):
|
||||
"""Cleans up the memory"""
|
||||
del self.dsolve
|
||||
self.dsolve = None
|
||||
|
||||
def solveDirect(self, b, factorize=False, backend='scipy'):
|
||||
"""
|
||||
Use solve instead of this interface.
|
||||
|
||||
:param bool factorize: if you want to factorize and store factors
|
||||
:param str backend: which backend to use. Default is scipy
|
||||
:rtype: numpy.ndarray
|
||||
:return: x
|
||||
"""
|
||||
assert np.shape(self.A)[1] == np.shape(b)[0], 'Dimension mismatch'
|
||||
|
||||
if factorize and self.dsolve is None:
|
||||
self.A = self.A.tocsc() # for efficiency
|
||||
self.dsolve = linalg.factorized(self.A)
|
||||
|
||||
if len(b.shape) == 1 or b.shape[1] == 1:
|
||||
# Just one RHS
|
||||
if factorize:
|
||||
return self.dsolve(b)
|
||||
else:
|
||||
return linalg.dsolve.spsolve(self.A, b)
|
||||
|
||||
# Multiple RHSs
|
||||
X = np.empty_like(b)
|
||||
for i in range(b.shape[1]):
|
||||
if factorize:
|
||||
X[:,i] = self.dsolve(b[:,i])
|
||||
else:
|
||||
X[:,i] = linalg.dsolve.spsolve(self.A,b[:,i])
|
||||
|
||||
return X
|
||||
|
||||
def solveIter(self, b, M=None, iterSolver='CG'):
|
||||
pass
|
||||
|
||||
def solveBackward(self, b, backend='python'):
|
||||
"""
|
||||
Use solve instead of this interface.
|
||||
|
||||
Perform a backwards solve with upper triangular A in CSR format (best, if not, it will be converted).
|
||||
|
||||
:param str backend: which backend to use. Default is python.
|
||||
:rtype: numpy.ndarray
|
||||
:return: x
|
||||
"""
|
||||
if type(self.A) is not sparse.csr.csr_matrix:
|
||||
from scipy.sparse import csr_matrix
|
||||
self.A = csr_matrix(self.A)
|
||||
vals = self.A.data
|
||||
rowptr = self.A.indptr
|
||||
colind = self.A.indices
|
||||
x = np.empty_like(b) # empty() is faster than zeros().
|
||||
for i in reversed(xrange(self.A.shape[0])):
|
||||
ith_row = vals[rowptr[i] : rowptr[i+1]]
|
||||
cols = colind[rowptr[i] : rowptr[i+1]]
|
||||
x_vals = x[cols]
|
||||
x[i] = (b[i] - np.dot(ith_row[1:], x_vals[1:])) / ith_row[0]
|
||||
return x
|
||||
|
||||
def solveForward(self, b, backend='python'):
|
||||
"""
|
||||
Use solve instead of this interface.
|
||||
|
||||
Perform a forward solve with lower triangular A in CSR format (best, if not, it will be converted).
|
||||
|
||||
:param str backend: which backend to use. Default is python.
|
||||
:rtype: numpy.ndarray
|
||||
:return: x
|
||||
"""
|
||||
if type(self.A) is not sparse.csr.csr_matrix:
|
||||
from scipy.sparse import csr_matrix
|
||||
self.A = csr_matrix(self.A)
|
||||
vals = self.A.data
|
||||
rowptr = self.A.indptr
|
||||
colind = self.A.indices
|
||||
x = np.empty_like(b) # empty() is faster than zeros().
|
||||
for i in xrange(self.A.shape[0]):
|
||||
ith_row = vals[rowptr[i] : rowptr[i+1]]
|
||||
cols = colind[rowptr[i] : rowptr[i+1]]
|
||||
x_vals = x[cols]
|
||||
x[i] = (b[i] - np.dot(ith_row[:-1], x_vals[:-1])) / ith_row[-1]
|
||||
return x
|
||||
|
||||
def solveDiagonal(self, b, backend='python'):
|
||||
"""
|
||||
Use solve instead of this interface.
|
||||
|
||||
Perform a diagonal solve with diagonal matrix A.
|
||||
|
||||
:param str backend: which backend to use. Default is python.
|
||||
:rtype: numpy.ndarray
|
||||
:return: x
|
||||
"""
|
||||
diagA = self.A.diagonal()
|
||||
if len(b.shape) == 1 or b.shape[1] == 1:
|
||||
# Just one RHS
|
||||
return b/diagA
|
||||
# Multiple RHSs
|
||||
X = np.empty_like(b)
|
||||
for i in range(b.shape[1]):
|
||||
X[:,i] = b[:,i]/diagA
|
||||
return X
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from SimPEG.mesh import TensorMesh
|
||||
from time import time
|
||||
h1 = np.ones(20)*100.
|
||||
h2 = np.ones(20)*100.
|
||||
h3 = np.ones(20)*100.
|
||||
|
||||
h = [h1,h2,h3]
|
||||
|
||||
M = TensorMesh(h)
|
||||
|
||||
D = M.faceDiv
|
||||
G = M.cellGrad
|
||||
Msig = M.getFaceMass()
|
||||
A = D*Msig*G
|
||||
A[0,0] *= 10 # remove the constant null space from the matrix
|
||||
|
||||
e = np.ones(M.nC)
|
||||
rhs = A.dot(e)
|
||||
|
||||
tic = time()
|
||||
solve = Solver(A, options={'factorize':True})
|
||||
x = solve.solve(rhs)
|
||||
print 'Factorized', time() - tic
|
||||
print np.linalg.norm(e-x,np.inf)
|
||||
tic = time()
|
||||
solve = Solver(A, options={'factorize':False})
|
||||
x = solve.solve(rhs)
|
||||
print 'spsolve', time() - tic
|
||||
print np.linalg.norm(e-x,np.inf)
|
||||
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import matutils
|
||||
import sputils
|
||||
import lomutils
|
||||
import interputils
|
||||
import ModelBuilder
|
||||
import Solver
|
||||
from Solver import Solver
|
||||
from matutils import getSubArray, mkvc, ndgrid, ind2sub, sub2ind
|
||||
from sputils import spzeros, kron3, speye, sdiag
|
||||
from lomutils import volTetra, faceInfo, inv2X2BlockDiagonal, inv3X3BlockDiagonal, indexCube, exampleLomGird
|
||||
from interputils import interpmat
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
from sputils import spzeros
|
||||
from matutils import mkvc, sub2ind
|
||||
|
||||
def _interp_point_1D(x, xr_i):
|
||||
"""
|
||||
given a point, xr_i, this will find which two integers it lies between.
|
||||
|
||||
:param numpy.ndarray x: Tensor vector of 1st dimension of grid.
|
||||
:param float xr_i: Location of a point
|
||||
:rtype: int,int,float,float
|
||||
:return: index1, index2, portion1, portion2
|
||||
"""
|
||||
# TODO: This fails if the point is on the outside of the mesh. We may want to replace this by extrapolation?
|
||||
im = np.argmin(abs(x-xr_i))
|
||||
if xr_i - x[im] >= 0: # Point on the left
|
||||
ind_x1 = im
|
||||
ind_x2 = im+1
|
||||
elif xr_i - x[im] < 0: # Point on the right
|
||||
ind_x1 = im-1
|
||||
ind_x2 = im
|
||||
dx1 = xr_i - x[ind_x1]
|
||||
dx2 = x[ind_x2] - xr_i
|
||||
return ind_x1, ind_x2, dx1, dx2
|
||||
|
||||
|
||||
def interpmat(locs, x, y=None, z=None):
|
||||
"""
|
||||
Local interpolation computed for each receiver point in turn
|
||||
|
||||
:param numpy.ndarray loc: Location of points to interpolate to
|
||||
:param numpy.ndarray x: Tensor vector of 1st dimension of grid.
|
||||
:param numpy.ndarray y: Tensor vector of 2nd dimension of grid. None by default.
|
||||
:param numpy.ndarray z: Tensor vector of 3rd dimension of grid. None by default.
|
||||
:rtype: scipy.sparse.csr.csr_matrix
|
||||
:return: Interpolation matrix
|
||||
|
||||
.. plot::
|
||||
|
||||
import SimPEG
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
locs = np.random.rand(50)*0.8+0.1
|
||||
x = np.linspace(0,1,7)
|
||||
dense = np.linspace(0,1,200)
|
||||
fun = lambda x: np.cos(2*np.pi*x)
|
||||
Q = SimPEG.utils.interpmat(locs, x)
|
||||
plt.plot(x, fun(x), 'bs-')
|
||||
plt.plot(dense, fun(dense), 'y:')
|
||||
plt.plot(locs, Q*fun(x), 'mo')
|
||||
plt.plot(locs, fun(locs), 'rx')
|
||||
plt.show()
|
||||
|
||||
"""
|
||||
if y is None and z is None:
|
||||
return _interpmat1D(locs, x)
|
||||
elif z is None:
|
||||
return _interpmat2D(locs, x, y)
|
||||
else:
|
||||
return _interpmat3D(locs, x, y, z)
|
||||
|
||||
|
||||
def _interpmat1D(locs, x):
|
||||
"""Use interpmat with only x component provided."""
|
||||
nx = x.size
|
||||
locs = mkvc(locs)
|
||||
npts = locs.shape[0]
|
||||
|
||||
Q = sp.lil_matrix((npts, nx))
|
||||
|
||||
for i in range(npts):
|
||||
ind_x1, ind_x2, dx1, dx2 = _interp_point_1D(x, locs[i])
|
||||
dv = (x[ind_x2] - x[ind_x1])
|
||||
Dx = x[ind_x2] - x[ind_x1]
|
||||
# Get the row in the matrix
|
||||
inds = [ind_x1, ind_x2]
|
||||
vals = [(1-dx1/Dx),(1-dx2/Dx)]
|
||||
Q[i, inds] = vals
|
||||
return Q.tocsr()
|
||||
|
||||
|
||||
|
||||
def _interpmat2D(locs, x, y):
|
||||
"""Use interpmat with only x and y components provided."""
|
||||
nx = x.size
|
||||
ny = y.size
|
||||
npts = locs.shape[0]
|
||||
|
||||
Q = sp.lil_matrix((npts, nx*ny))
|
||||
|
||||
|
||||
for i in range(npts):
|
||||
ind_x1, ind_x2, dx1, dx2 = _interp_point_1D(x, locs[i, 0])
|
||||
ind_y1, ind_y2, dy1, dy2 = _interp_point_1D(y, locs[i, 1])
|
||||
|
||||
dv = (x[ind_x2] - x[ind_x1]) * (y[ind_y2] - y[ind_y1])
|
||||
|
||||
Dx = x[ind_x2] - x[ind_x1]
|
||||
Dy = y[ind_y2] - y[ind_y1]
|
||||
|
||||
# Get the row in the matrix
|
||||
|
||||
inds = sub2ind((nx,ny),[
|
||||
( ind_x1, ind_y2),
|
||||
( ind_x1, ind_y1),
|
||||
( ind_x2, ind_y1),
|
||||
( ind_x2, ind_y2)])
|
||||
|
||||
vals = [(1-dx1/Dx)*(1-dy2/Dy),
|
||||
(1-dx1/Dx)*(1-dy1/Dy),
|
||||
(1-dx2/Dx)*(1-dy1/Dy),
|
||||
(1-dx2/Dx)*(1-dy2/Dy)]
|
||||
|
||||
Q[i, mkvc(inds)] = vals
|
||||
|
||||
return Q.tocsr()
|
||||
|
||||
|
||||
|
||||
def _interpmat3D(locs, x, y, z):
|
||||
"""Use interpmat."""
|
||||
nx = x.size
|
||||
ny = y.size
|
||||
nz = z.size
|
||||
npts = locs.shape[0]
|
||||
|
||||
Q = sp.lil_matrix((npts, nx*ny*nz))
|
||||
|
||||
|
||||
for i in range(npts):
|
||||
ind_x1, ind_x2, dx1, dx2 = _interp_point_1D(x, locs[i, 0])
|
||||
ind_y1, ind_y2, dy1, dy2 = _interp_point_1D(y, locs[i, 1])
|
||||
ind_z1, ind_z2, dz1, dz2 = _interp_point_1D(z, locs[i, 2])
|
||||
|
||||
dv = (x[ind_x2] - x[ind_x1]) * (y[ind_y2] - y[ind_y1]) *(z[ind_z2] - z[ind_z1])
|
||||
|
||||
Dx = x[ind_x2] - x[ind_x1]
|
||||
Dy = y[ind_y2] - y[ind_y1]
|
||||
Dz = z[ind_z2] - z[ind_z1]
|
||||
|
||||
# Get the row in the matrix
|
||||
|
||||
inds = sub2ind((nx,ny,nz),[
|
||||
( ind_x1, ind_y2, ind_z1),
|
||||
( ind_x1, ind_y1, ind_z1),
|
||||
( ind_x2, ind_y1, ind_z1),
|
||||
( ind_x2, ind_y2, ind_z1),
|
||||
( ind_x1, ind_y1, ind_z2),
|
||||
( ind_x1, ind_y2, ind_z2),
|
||||
( ind_x2, ind_y1, ind_z2),
|
||||
( ind_x2, ind_y2, ind_z2)])
|
||||
|
||||
vals = [(1-dx1/Dx)*(1-dy2/Dy)*(1-dz1/Dz),
|
||||
(1-dx1/Dx)*(1-dy1/Dy)*(1-dz1/Dz),
|
||||
(1-dx2/Dx)*(1-dy1/Dy)*(1-dz1/Dz),
|
||||
(1-dx2/Dx)*(1-dy2/Dy)*(1-dz1/Dz),
|
||||
(1-dx1/Dx)*(1-dy1/Dy)*(1-dz2/Dz),
|
||||
(1-dx1/Dx)*(1-dy2/Dy)*(1-dz2/Dz),
|
||||
(1-dx2/Dx)*(1-dy1/Dy)*(1-dz2/Dz),
|
||||
(1-dx2/Dx)*(1-dy2/Dy)*(1-dz2/Dz)]
|
||||
|
||||
Q[i, mkvc(inds)] = vals
|
||||
|
||||
return Q.tocsr()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import SimPEG
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
locs = np.random.rand(50)*0.8+0.1
|
||||
x = np.linspace(0,1,7)
|
||||
dense = np.linspace(0,1,200)
|
||||
fun = lambda x: np.cos(2*np.pi*x)
|
||||
Q = SimPEG.utils.interpmat(locs, x)
|
||||
plt.plot(x, fun(x), 'bs-')
|
||||
plt.plot(dense, fun(dense), 'y:')
|
||||
plt.plot(locs, Q*fun(x), 'mo')
|
||||
plt.plot(locs, fun(locs), 'rx')
|
||||
plt.show()
|
||||
@@ -1,8 +0,0 @@
|
||||
.. _api_LOMView:
|
||||
|
||||
LOM View
|
||||
********
|
||||
|
||||
.. automodule:: SimPEG.mesh.LomView
|
||||
:members:
|
||||
:undoc-members:
|
||||
@@ -6,3 +6,11 @@ Logically Orthogonal Mesh
|
||||
.. automodule:: SimPEG.mesh.LogicallyOrthogonalMesh
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
|
||||
LOM View
|
||||
********
|
||||
|
||||
.. automodule:: SimPEG.mesh.LomView
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
@@ -6,3 +6,18 @@ Optimize
|
||||
.. automodule:: SimPEG.inverse.Optimize
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
|
||||
Inversion
|
||||
*********
|
||||
|
||||
.. automodule:: SimPEG.inverse.Inversion
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
Beta Schedule
|
||||
*************
|
||||
|
||||
.. automodule:: SimPEG.inverse.BetaSchedule
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
@@ -13,14 +13,16 @@ Problem
|
||||
DCProblem
|
||||
*********
|
||||
|
||||
.. automodule:: SimPEG.forward.DCProblem.DCProblem
|
||||
.. automodule:: SimPEG.forward.DCProblem
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
|
||||
DCutils
|
||||
*******
|
||||
|
||||
.. automodule:: SimPEG.forward.DCProblem.DCutils
|
||||
Linear Problem
|
||||
**************
|
||||
|
||||
.. automodule:: SimPEG.forward.LinearProblem
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
.. _api_Solver:
|
||||
|
||||
Solver
|
||||
******
|
||||
|
||||
.. automodule:: SimPEG.utils.Solver
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
@@ -6,3 +6,10 @@ Tensor Mesh
|
||||
.. automodule:: SimPEG.mesh.TensorMesh
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
Tensor View
|
||||
***********
|
||||
|
||||
.. automodule:: SimPEG.mesh.TensorView
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
.. _api_TensorView:
|
||||
|
||||
Tensor View
|
||||
***********
|
||||
|
||||
.. automodule:: SimPEG.mesh.TensorView
|
||||
:members:
|
||||
:undoc-members:
|
||||
File diff suppressed because it is too large
Load Diff
@@ -19,3 +19,7 @@ Utilities
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
.. automodule:: SimPEG.utils.interputils
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from SimPEG import LogicallyOrthogonalMesh, utils
|
||||
from SimPEG.mesh import LogicallyOrthogonalMesh
|
||||
from SimPEG import utils
|
||||
import matplotlib.pyplot as plt
|
||||
X, Y = utils.exampleLomGird([3,3],'rotate')
|
||||
M = LogicallyOrthogonalMesh([X, Y])
|
||||
|
||||
+2
-7
@@ -1,8 +1,3 @@
|
||||
.. SimPEG documentation master file, created by
|
||||
sphinx-quickstart on Fri Aug 30 18:42:44 2013.
|
||||
You can adapt this file completely to your liking, but it should at least
|
||||
contain the root `toctree` directive.
|
||||
|
||||
SimPEG
|
||||
======
|
||||
|
||||
@@ -24,10 +19,8 @@ Meshing & Operators
|
||||
|
||||
api_BaseMesh
|
||||
api_TensorMesh
|
||||
api_TensorView
|
||||
api_LogicallyOrthogonalMesh
|
||||
api_Cyl1DMesh
|
||||
api_LOMView
|
||||
api_DiffOperators
|
||||
api_InnerProducts
|
||||
|
||||
@@ -54,6 +47,7 @@ Testing SimPEG
|
||||
:maxdepth: 2
|
||||
|
||||
api_Tests
|
||||
api_TestResults
|
||||
|
||||
|
||||
Utility Codes
|
||||
@@ -62,6 +56,7 @@ Utility Codes
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
||||
api_Solver
|
||||
api_Utils
|
||||
|
||||
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
numpy
|
||||
pypubsub
|
||||
|
||||
Reference in New Issue
Block a user