mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-02 08:30:17 +08:00
cleanup and minimal docs for Jvec, JTvec
This commit is contained in:
+35
-104
@@ -134,6 +134,18 @@ class BaseTDEMProblem(Problem.BaseTimeProblem, BaseEMProblem):
|
||||
|
||||
def Jtvec(self, m, v, u=None):
|
||||
|
||||
"""
|
||||
Jvec computes the adjoint of the sensitivity times a vector
|
||||
|
||||
.. math::
|
||||
\mathbf{J}^\\top \mathbf{v} = \left( \\frac{d\mathbf{u}}{d\mathbf{m}} ^ \\top \\frac{d\mathbf{F}}{d\mathbf{u}} ^ \\top + \\frac{\partial\mathbf{F}}{\partial\mathbf{m}} ^ \\top \\right) \\frac{d\mathbf{P}}{d\mathbf{F}} ^ \\top \mathbf{v}
|
||||
|
||||
where
|
||||
|
||||
.. math::
|
||||
\\frac{d\mathbf{u}}{d\mathbf{m}} ^\\top \mathbf{A}^\\top + \\frac{d\mathbf{A}(\mathbf{u})}{d\mathbf{m}} ^ \\top = \\frac{d \mathbf{RHS}}{d \mathbf{m}} ^ \\top
|
||||
"""
|
||||
|
||||
if u is None:
|
||||
u = self.fields(m)
|
||||
|
||||
@@ -144,24 +156,20 @@ class BaseTDEMProblem(Problem.BaseTimeProblem, BaseEMProblem):
|
||||
if not isinstance(v, self.dataPair):
|
||||
v = self.dataPair(self.survey, v)
|
||||
|
||||
# TODO: make this general
|
||||
# if self._fieldType is 'b':
|
||||
# dun_dmT_v = np.zeros((len(m), self.survey.nSrc))
|
||||
|
||||
|
||||
PT_v = Fields_Derivs(self.mesh, self.survey) #PT_v is a fields object
|
||||
|
||||
|
||||
# TODO: This will only work for b formulation right now b/c of the mesh.nF
|
||||
df_duT_v = np.zeros((self.mesh.nF,self.nT+1))
|
||||
ATinv_df_duT_v = np.zeros((self.mesh.nF,self.nT+1))
|
||||
if ftype is 'bSolution' or 'jSolution':
|
||||
df_duT_v = np.zeros((self.mesh.nF,self.nT+1))
|
||||
ATinv_df_duT_v = np.zeros((self.mesh.nF,self.nT+1))
|
||||
elif ftype is 'eSolution' or 'hSolution':
|
||||
df_duT_v = np.zeros((self.mesh.nE,self.nT+1))
|
||||
ATinv_df_duT_v = np.zeros((self.mesh.nE,self.nT+1))
|
||||
|
||||
JTv = np.zeros(m.shape)
|
||||
|
||||
|
||||
# TODO : this is pretty ugly
|
||||
|
||||
# Loop over sources and receivers to create a fields object: PT_v
|
||||
# Loop over sources and receivers to create a fields object: PT_v, df_duT_v, df_dmT_v
|
||||
for src in self.survey.srcList:
|
||||
# initialize empty fields derivs
|
||||
for projField in set([rx.projField for rx in src.rxList]):
|
||||
@@ -173,8 +181,6 @@ class BaseTDEMProblem(Problem.BaseTimeProblem, BaseEMProblem):
|
||||
curPT_v = rx.evalDeriv(src, self.mesh, self.timeMesh, Utils.mkvc(v[src,rx]), adjoint=True)
|
||||
PT_v[src,'%sDeriv'%rx.projField, :] += np.reshape(curPT_v,(len(curPT_v)/self.timeMesh.nN, self.timeMesh.nN), order='F') # All the fields for a given src, reciever.
|
||||
|
||||
# print np.linalg.norm(PT_v[src,'bDeriv',:])
|
||||
# for src in self.survey.srcList:
|
||||
# initialize empty fields derivs
|
||||
for projField in set([rx.projField for rx in src.rxList]):
|
||||
df_duTFun = getattr(u, '_%sDeriv'%projField, None)
|
||||
@@ -184,19 +190,16 @@ class BaseTDEMProblem(Problem.BaseTimeProblem, BaseEMProblem):
|
||||
JTv = JTv + df_dmT_v
|
||||
df_duT_v = df_duT_v + df_duT_v_cur
|
||||
|
||||
# print np.linalg.norm(df_duT_v)
|
||||
|
||||
AdiagTinv = None
|
||||
|
||||
|
||||
# for tInd in reversed(range(self.nT)): #enumerate(reversed(list(self.timeSteps))):
|
||||
for tInd in reversed(range(self.nT)) : # reversed(self.timeSteps)):
|
||||
if AdiagTinv is not None: # and (tInd <= self.nT and dt != self.timeSteps[tInd]):
|
||||
# (tInd < self.nT and self.timeSteps[tInd] != self.timeSteps[tInd + 1]):# keep factors if dt is the same as previous step b/c A will be the same
|
||||
# Do the back-solve through time
|
||||
for tInd in reversed(range(self.nT)):
|
||||
if AdiagTinv is not None and (tInd <= self.nT and self.timeSteps[tInd] != self.timeSteps[tInd+1]): # if the previous timestep is the same --> no need to refactor the matrix
|
||||
AdiagTinv.clean()
|
||||
AdiagTinv = None
|
||||
|
||||
|
||||
# refactor if we need to
|
||||
if AdiagTinv is None:
|
||||
Adiag = self.getAdiag(tInd)
|
||||
AdiagTinv = self.Solver(Adiag.T, **self.solverOpts)
|
||||
@@ -208,73 +211,24 @@ class BaseTDEMProblem(Problem.BaseTimeProblem, BaseEMProblem):
|
||||
Asubdiag = self.getAsubdiag(tInd+1)
|
||||
ATinv_df_duT_v[:,tInd+1] = AdiagTinv * (df_duT_v[:,tInd+1] - Asubdiag.T * ATinv_df_duT_v[:,tInd+2])
|
||||
|
||||
# for src in self.survey.srcList:
|
||||
un_src = u[src,ftype,tInd+1]
|
||||
dAT_dm_v = self.getAdiagDeriv(None, un_src, ATinv_df_duT_v[:,tInd+1], adjoint=True) # cell centered on time mesh
|
||||
for src in self.survey.srcList:
|
||||
un_src = u[src,ftype,tInd+1]
|
||||
dAT_dm_v = self.getAdiagDeriv(None, un_src, ATinv_df_duT_v[:,tInd+1], adjoint=True) # cell centered on time mesh
|
||||
|
||||
dRHST_dm_v = self.getRHSDeriv(tInd+1, src, ATinv_df_duT_v[:,tInd+1], adjoint=True) # on nodes of time mesh
|
||||
# dAsubdiag_dm_v = 0
|
||||
dRHST_dm_v = self.getRHSDeriv(tInd+1, src, ATinv_df_duT_v[:,tInd+1], adjoint=True) # on nodes of time mesh
|
||||
# dAsubdiag_dm_v = 0
|
||||
|
||||
JTv = JTv + (-dAT_dm_v + dRHST_dm_v)
|
||||
|
||||
# JTv = JTv +
|
||||
|
||||
# tInd = 0
|
||||
# un_src = u[src,ftype,tInd]
|
||||
# # dAT_dm_v = self.getAdiagDeriv(None, un_src, self.getInitialFieldsDeriv(), adjoint=True)
|
||||
# Asubdiag = self.getAsubdiag(tInd)
|
||||
# ATinv_df_duT_v[:,tInd] = AdiagTinv * (- Asubdiag.T * df_duT_v[:,tInd])
|
||||
# # - self.getAsubdiag(tInd).T * df_duT_v[:,tInd]
|
||||
# dAT_dm_v = self.getAdiagDeriv(None, un_src, ATinv_df_duT_v[:,tInd], adjoint=True) # cell centered on time mesh
|
||||
|
||||
# dRHST_dm_v = self.getRHSDeriv(tInd, src, ATinv_df_duT_v[:,tInd], adjoint=True)
|
||||
|
||||
# JTv = JTv + (- dAT_dm_v + dRHST_dm_v)
|
||||
JTv = JTv + (-dAT_dm_v + dRHST_dm_v)
|
||||
|
||||
# this doesn't include initial fields deriv
|
||||
|
||||
# adding du_dm^T * dF_du^T * P^T vfor time 0 (no dRHS_dm_v at time 0)
|
||||
JTv = JTv + self.getInitialFieldsDeriv(df_duT_v[:,0], adjoint=True)
|
||||
|
||||
return JTv
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# for i, src in enumerate(self.survey.srcList):
|
||||
|
||||
# un_src = u[src,ftype,tInd+1] # fields for this source at tInd
|
||||
|
||||
# for rx in src.rxList:
|
||||
|
||||
# df_duTFun = getattr(u, '_%sDeriv'%rx.projField, None)
|
||||
# df_duT_v, df_dmT_v = df_duTFun(tInd, src, None, PT_v[src,'%sDeriv'%rx.projField,tInd], adjoint=True)
|
||||
|
||||
# ATinv_df_duT_v = AdiagTinv * df_duT_v
|
||||
|
||||
# dAT_dm_v = self.getAdiagDeriv(tInd, un_src, ATinv_df_duT_v, adjoint=True)
|
||||
# dRHST_dm_v = self.getRHSDeriv(tInd, src, ATinv_df_duT_v)
|
||||
# # dAsubdiagT_dm_v = 0
|
||||
|
||||
# print dAT_dm_v.shape, Asubdiag.shape, ATinv_df_duT_v.shape, dun_dmT_v.shape,
|
||||
# dun_dmT_v[:,i] = (dRHST_dm_v - dAT_dm_v - Asubdiag.T*dun_dmT_v[:,i])
|
||||
|
||||
# # rhsT_v = self.getJRHS(tInd, src, u_src, ATinv_df_duT_v, dun_dmT_v[:,i], adjoint = True)
|
||||
|
||||
# JTv = JTv + dun_dmT_v
|
||||
|
||||
# return Utils.mkvc(JTv)
|
||||
|
||||
|
||||
|
||||
# def getJRHS(self, tInd, src, u, v, adjoint = False):
|
||||
|
||||
# dA_dm = self.getADeriv(tInd, u, v, adjoint)
|
||||
# dRHS_dm = self.getRHSDeriv(tInd, src, v, adjoint)
|
||||
|
||||
# b = - dA_dm + dRHS_dm
|
||||
|
||||
# return b
|
||||
|
||||
|
||||
def getSourceTerm(self, tInd):
|
||||
|
||||
Srcs = self.survey.srcList
|
||||
@@ -319,6 +273,9 @@ class BaseTDEMProblem(Problem.BaseTimeProblem, BaseEMProblem):
|
||||
for i,src in enumerate(Srcs):
|
||||
ifieldsDeriv[:,i] = ifieldsDeriv[:,i] + getattr(src, '%sInitialDeriv'%self._fieldType, None)(self,v,adjoint)
|
||||
|
||||
if adjoint is True:
|
||||
ifieldsDeriv = ifieldsDeriv.sum()
|
||||
|
||||
return ifieldsDeriv
|
||||
|
||||
|
||||
@@ -470,32 +427,6 @@ class Problem_b(BaseTDEMProblem):
|
||||
return RHSDeriv
|
||||
|
||||
|
||||
@Utils.timeIt
|
||||
def getJdiags(self, tInd, adjoint = False):
|
||||
# The matrix that we are computing has the form:
|
||||
#
|
||||
# - - - - - -
|
||||
# | Adiag | | uderiv1 | | b1 |
|
||||
# | Asub Adiag | | uderiv2 | | b2 |
|
||||
# | Asub Adiag | | uderiv3 | = | b3 |
|
||||
# | ... ... | | ... | | .. |
|
||||
# | Asub Adiag | | uderivn | | bn |
|
||||
# - - - - - -
|
||||
|
||||
if adjoint:
|
||||
raise NotImplementedError
|
||||
|
||||
dt = self.timeSteps[tInd]
|
||||
|
||||
Adiag = self.getA(tInd)
|
||||
Asub = - 1./dt * Utils.speye(self.mesh.nF)
|
||||
|
||||
if self._makeASymmetric:
|
||||
Asub = self.MfMui.T * Asub
|
||||
|
||||
return Adiag, Asub
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user