mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-04 11:55:17 +08:00
first pass at multisrc Jtvec (will be hugely memory inefficient at the moment)
This commit is contained in:
+25
-20
@@ -158,17 +158,18 @@ class BaseTDEMProblem(Problem.BaseTimeProblem, BaseEMProblem):
|
||||
|
||||
PT_v = Fields_Derivs(self.mesh, self.survey) #PT_v is a fields object
|
||||
|
||||
if ftype is 'bSolution' or 'jSolution':
|
||||
df_duT_v = np.zeros((self.mesh.nF,self.nT+1))
|
||||
ATinv_df_duT_v = np.zeros((self.mesh.nF,self.nT+1))
|
||||
elif ftype is 'eSolution' or 'hSolution':
|
||||
df_duT_v = np.zeros((self.mesh.nE,self.nT+1))
|
||||
ATinv_df_duT_v = np.zeros((self.mesh.nE,self.nT+1))
|
||||
df_duT_v = Fields_Derivs(self.mesh, self.survey)
|
||||
ATinv_df_duT_v = Fields_Derivs(self.mesh, self.survey)
|
||||
# if ftype is 'bSolution' or 'jSolution':
|
||||
# # df_duT_v = np.zeros((self.mesh.nF,self.nT+1))
|
||||
# ATinv_df_duT_v = np.zeros((self.mesh.nF,self.nT+1))
|
||||
# elif ftype is 'eSolution' or 'hSolution':
|
||||
# # df_duT_v = np.zeros((self.mesh.nE,self.nT+1))
|
||||
# ATinv_df_duT_v = np.zeros((self.mesh.nE,self.nT+1))
|
||||
|
||||
JTv = np.zeros(m.shape)
|
||||
|
||||
|
||||
|
||||
# Loop over sources and receivers to create a fields object: PT_v, df_duT_v, df_dmT_v
|
||||
for src in self.survey.srcList:
|
||||
# initialize empty fields derivs
|
||||
@@ -188,7 +189,7 @@ class BaseTDEMProblem(Problem.BaseTimeProblem, BaseEMProblem):
|
||||
df_duT_v_cur, df_dmT_v = df_duTFun(None, src, None, PT_v[src,'%sDeriv'%projField,:], adjoint=True)
|
||||
|
||||
JTv = JTv + df_dmT_v
|
||||
df_duT_v = df_duT_v + df_duT_v_cur
|
||||
df_duT_v[src, '%sDeriv'%self._fieldType, :] = df_duT_v_cur
|
||||
|
||||
|
||||
AdiagTinv = None
|
||||
@@ -204,28 +205,32 @@ class BaseTDEMProblem(Problem.BaseTimeProblem, BaseEMProblem):
|
||||
Adiag = self.getAdiag(tInd)
|
||||
AdiagTinv = self.Solver(Adiag.T, **self.solverOpts)
|
||||
|
||||
# solve against df_duT_v
|
||||
if tInd >= self.nT-1:
|
||||
ATinv_df_duT_v[:,tInd+1] = AdiagTinv * df_duT_v[:,tInd+1]
|
||||
else:
|
||||
if tInd < self.nT - 1:
|
||||
Asubdiag = self.getAsubdiag(tInd+1)
|
||||
ATinv_df_duT_v[:,tInd+1] = AdiagTinv * (df_duT_v[:,tInd+1] - Asubdiag.T * ATinv_df_duT_v[:,tInd+2])
|
||||
|
||||
|
||||
for src in self.survey.srcList:
|
||||
un_src = u[src,ftype,tInd+1]
|
||||
dAT_dm_v = self.getAdiagDeriv(None, un_src, ATinv_df_duT_v[:,tInd+1], adjoint=True) # cell centered on time mesh
|
||||
# solve against df_duT_v
|
||||
if tInd >= self.nT-1:
|
||||
ATinv_df_duT_v[src,'%sDeriv'%self._fieldType,tInd+1] = AdiagTinv * df_duT_v[src,'%sDeriv'%self._fieldType,tInd+1]
|
||||
else:
|
||||
ATinv_df_duT_v[src,'%sDeriv'%self._fieldType,tInd+1] = AdiagTinv * (Utils.mkvc(df_duT_v[src,'%sDeriv'%self._fieldType,tInd+1]) - Asubdiag.T * Utils.mkvc(ATinv_df_duT_v[src,'%sDeriv'%self._fieldType,tInd+2]))
|
||||
|
||||
dRHST_dm_v = self.getRHSDeriv(tInd+1, src, ATinv_df_duT_v[:,tInd+1], adjoint=True) # on nodes of time mesh
|
||||
un_src = u[src,ftype,tInd+1]
|
||||
dAT_dm_v = self.getAdiagDeriv(None, un_src, ATinv_df_duT_v[src, '%sDeriv'%self._fieldType,tInd+1], adjoint=True) # cell centered on time mesh
|
||||
|
||||
dRHST_dm_v = self.getRHSDeriv(tInd+1, src, ATinv_df_duT_v[src, '%sDeriv'%self._fieldType,tInd+1], adjoint=True) # on nodes of time mesh
|
||||
# dAsubdiag_dm_v = 0
|
||||
|
||||
JTv = JTv + (-dAT_dm_v + dRHST_dm_v)
|
||||
JTv = JTv + Utils.mkvc(-dAT_dm_v + dRHST_dm_v)
|
||||
|
||||
# this doesn't include initial fields deriv
|
||||
|
||||
# adding du_dm^T * dF_du^T * P^T vfor time 0 (no dRHS_dm_v at time 0)
|
||||
JTv = JTv + self.getInitialFieldsDeriv(df_duT_v[:,0], adjoint=True)
|
||||
for src in self.survey.srcList:
|
||||
JTv = JTv + self.getInitialFieldsDeriv(Utils.mkvc(df_duT_v[src,'%sDeriv'%self._fieldType,0]), adjoint=True)
|
||||
|
||||
return JTv
|
||||
|
||||
return Utils.mkvc(JTv)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,9 @@ from SimPEG import *
|
||||
from SimPEG import EM
|
||||
|
||||
plotIt = False
|
||||
testDeriv = True
|
||||
testAdjoint = True
|
||||
|
||||
|
||||
class TDEM_bDerivTests(unittest.TestCase):
|
||||
|
||||
@@ -21,14 +24,14 @@ class TDEM_bDerivTests(unittest.TestCase):
|
||||
mapping = Maps.ExpMap(mesh) * Maps.SurjectVertical1D(mesh) * activeMap
|
||||
|
||||
rxOffset = 40.
|
||||
rx = EM.TDEM.RxTDEM(np.array([[rxOffset, 0., 0.]]), np.logspace(-4,-3, 20), 'bz')
|
||||
src = EM.TDEM.SrcTDEM_VMD_MVP( [rx], loc=np.array([0., 0., 0.]))
|
||||
rx2 = EM.TDEM.RxTDEM(np.array([[rxOffset-10, 0., 0.]]), np.logspace(-5,-4, 25), 'bz')
|
||||
src2 = EM.TDEM.SrcTDEM_VMD_MVP( [rx2], loc=np.array([0., 0., 0.]))
|
||||
rx = EM.TDEM.Rx(np.array([[rxOffset, 0., 0.]]), np.logspace(-4,-3, 20), 'bz')
|
||||
src = EM.TDEM.SurveyTDEM.MagDipole( [rx], loc=np.array([0., 0., 0.]))
|
||||
rx2 = EM.TDEM.Rx(np.array([[rxOffset-10, 0., 0.]]), np.logspace(-5,-4, 25), 'bz')
|
||||
src2 = EM.TDEM.SurveyTDEM.MagDipole( [rx2], loc=np.array([0., 0., 0.]))
|
||||
|
||||
survey = EM.TDEM.SurveyTDEM([src,src2])
|
||||
survey = EM.TDEM.Survey([src,src2])
|
||||
|
||||
self.prb = EM.TDEM.ProblemTDEM_b(mesh, mapping=mapping)
|
||||
self.prb = EM.TDEM.Problem_b(mesh, mapping=mapping)
|
||||
# self.prb.timeSteps = [1e-5]
|
||||
self.prb.timeSteps = [(1e-05, 10), (5e-05, 10), (2.5e-4, 10)]
|
||||
# self.prb.timeSteps = [(1e-05, 100)]
|
||||
@@ -39,115 +42,38 @@ class TDEM_bDerivTests(unittest.TestCase):
|
||||
except ImportError, e:
|
||||
self.prb.Solver = SolverLU
|
||||
|
||||
self.sigma = np.ones(mesh.nCz)*1e-8
|
||||
self.sigma[mesh.vectorCCz<0] = 1e-1
|
||||
self.sigma = np.log(self.sigma[active])
|
||||
self.m = np.log(1e-1)*np.ones(self.prb.mapping.nP) + 1e-2*np.random.randn(self.prb.mapping.nP)
|
||||
|
||||
self.prb.pair(survey)
|
||||
self.mesh = mesh
|
||||
|
||||
def test_DerivG(self):
|
||||
"""
|
||||
Test the derivative of c with respect to sigma
|
||||
"""
|
||||
if testDeriv:
|
||||
def test_Deriv_J(self):
|
||||
|
||||
# Random model and perturbation
|
||||
sigma = np.random.rand(self.prb.mapping.nP)
|
||||
prb = self.prb
|
||||
prb.timeSteps = [(1e-05, 10), (0.0001, 10), (0.001, 10)]
|
||||
mesh = self.mesh
|
||||
|
||||
f = self.prb.fields(sigma)
|
||||
dm = 1000*np.random.rand(self.prb.mapping.nP)
|
||||
h = 0.01
|
||||
derChk = lambda m: [prb.survey.dpred(m), lambda mx: prb.Jvec(self.m, mx)]
|
||||
print '\n'
|
||||
print 'test_Deriv_J'
|
||||
Tests.checkDerivative(derChk, self.m, plotIt=False, num=3, eps=1e-20)
|
||||
|
||||
derChk = lambda m: [self.prb._AhVec(m, f).tovec(), lambda mx: self.prb.Gvec(sigma, mx, u=f).tovec()]
|
||||
print '\ntest_DerivG'
|
||||
Tests.checkDerivative(derChk, sigma, plotIt=False, dx=dm, num=4, eps=1e-20)
|
||||
if testAdjoint:
|
||||
def test_adjointJvecVsJtvec(self):
|
||||
mesh = self.mesh
|
||||
prb = self.prb
|
||||
m0 = self.m
|
||||
|
||||
# def test_Deriv_dUdM(self):
|
||||
m = np.random.rand(prb.mapping.nP)
|
||||
d = np.random.rand(prb.survey.nD)
|
||||
|
||||
# prb = self.prb
|
||||
# prb.timeSteps = [(1e-05, 10), (0.0001, 10), (0.001, 10)]
|
||||
# mesh = self.mesh
|
||||
# sigma = self.sigma
|
||||
|
||||
# dm = 10*np.random.rand(prb.mapping.nP)
|
||||
# f = prb.fields(sigma)
|
||||
|
||||
# derChk = lambda m: [self.prb.fields(m).tovec(), lambda mx: -prb.solveAh(sigma, prb.Gvec(sigma, mx, u=f)).tovec()]
|
||||
# print '\n'
|
||||
# print 'test_Deriv_dUdM'
|
||||
# Tests.checkDerivative(derChk, sigma, plotIt=False, dx=dm, num=4, eps=1e-20)
|
||||
|
||||
# def test_Deriv_J(self):
|
||||
|
||||
# prb = self.prb
|
||||
# prb.timeSteps = [(1e-05, 10), (0.0001, 10), (0.001, 10)]
|
||||
# mesh = self.mesh
|
||||
# sigma = self.sigma
|
||||
|
||||
# # d_sig = 0.8*sigma #np.random.rand(mesh.nCz)
|
||||
# d_sig = 10*np.random.rand(prb.mapping.nP)
|
||||
|
||||
|
||||
# derChk = lambda m: [prb.survey.dpred(m), lambda mx: prb.Jvec(sigma, mx)]
|
||||
# print '\n'
|
||||
# print 'test_Deriv_J'
|
||||
# Tests.checkDerivative(derChk, sigma, plotIt=False, dx=d_sig, num=4, eps=1e-20)
|
||||
|
||||
# def test_projectAdjoint(self):
|
||||
# prb = self.prb
|
||||
# survey = prb.survey
|
||||
# nSrc = survey.nSrc
|
||||
# mesh = self.mesh
|
||||
|
||||
# # Generate random fields and data
|
||||
# f = EM.TDEM.FieldsTDEM(prb.mesh, prb.survey)
|
||||
# for i in range(prb.nT):
|
||||
# f[:,'b',i] = np.random.rand(mesh.nF, nSrc)
|
||||
# f[:,'e',i] = np.random.rand(mesh.nE, nSrc)
|
||||
# d_vec = np.random.rand(survey.nD)
|
||||
# d = Survey.Data(survey,v=d_vec)
|
||||
|
||||
# # Check that d.T*Q*f = f.T*Q.T*d
|
||||
# V1 = d_vec.dot(survey.evalDeriv(None, v=f).tovec())
|
||||
# V2 = np.sum((f.tovec())*(survey.evalDeriv(None, v=d, adjoint=True).tovec()))
|
||||
|
||||
# self.assertTrue((V1-V2)/np.abs(V1) < 1e-6)
|
||||
|
||||
# def test_adjointGvecVsGtvec(self):
|
||||
# mesh = self.mesh
|
||||
# prb = self.prb
|
||||
|
||||
# m = np.random.rand(prb.mapping.nP)
|
||||
# sigma = np.random.rand(prb.mapping.nP)
|
||||
|
||||
# u = EM.TDEM.FieldsTDEM(prb.mesh, prb.survey)
|
||||
# for i in range(1,prb.nT+1):
|
||||
# u[:,'b',i] = np.random.rand(mesh.nF, 2)
|
||||
# u[:,'e',i] = np.random.rand(mesh.nE, 2)
|
||||
|
||||
# v = EM.TDEM.FieldsTDEM(prb.mesh, prb.survey)
|
||||
# for i in range(1,prb.nT+1):
|
||||
# v[:,'b',i] = np.random.rand(mesh.nF, 2)
|
||||
# v[:,'e',i] = np.random.rand(mesh.nE, 2)
|
||||
|
||||
# V1 = m.dot(prb.Gtvec(sigma, v, u))
|
||||
# V2 = np.sum(v.tovec()*prb.Gvec(sigma, m, u).tovec())
|
||||
# self.assertTrue(np.abs(V1-V2)/np.abs(V1) <1e-6)
|
||||
|
||||
# def test_adjointJvecVsJtvec(self):
|
||||
# mesh = self.mesh
|
||||
# prb = self.prb
|
||||
# sigma = self.sigma
|
||||
|
||||
# m = np.random.rand(prb.mapping.nP)
|
||||
# d = np.random.rand(prb.survey.nD)
|
||||
|
||||
# V1 = d.dot(prb.Jvec(sigma, m))
|
||||
# V2 = m.dot(prb.Jtvec(sigma, d))
|
||||
# print 'AdjointTest', V1, V2
|
||||
# self.assertTrue(np.abs(V1-V2)/np.abs(V1) < 1e-6)
|
||||
V1 = d.dot(prb.Jvec(m0, m))
|
||||
V2 = m.dot(prb.Jtvec(m0, d))
|
||||
print 'AdjointTest', V1, V2
|
||||
self.assertTrue(np.abs(V1-V2)/np.abs(V1) < 1e-6)
|
||||
|
||||
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# unittest.main()
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user