cleanup and minimal docs for Jvec, JTvec

This commit is contained in:
Lindsey Heagy
2016-03-13 11:09:11 -07:00
parent a1ecef0709
commit c708ceb53d
+35 -104
View File
@@ -134,6 +134,18 @@ class BaseTDEMProblem(Problem.BaseTimeProblem, BaseEMProblem):
def Jtvec(self, m, v, u=None):
"""
Jtvec computes the adjoint of the sensitivity times a vector
.. math::
\mathbf{J}^\\top \mathbf{v} = \left( \\frac{d\mathbf{u}}{d\mathbf{m}} ^ \\top \\frac{d\mathbf{F}}{d\mathbf{u}} ^ \\top + \\frac{\partial\mathbf{F}}{\partial\mathbf{m}} ^ \\top \\right) \\frac{d\mathbf{P}}{d\mathbf{F}} ^ \\top \mathbf{v}
where
.. math::
\\frac{d\mathbf{u}}{d\mathbf{m}} ^\\top \mathbf{A}^\\top + \\frac{d\mathbf{A}(\mathbf{u})}{d\mathbf{m}} ^ \\top = \\frac{d \mathbf{RHS}}{d \mathbf{m}} ^ \\top
"""
if u is None:
u = self.fields(m)
@@ -144,24 +156,20 @@ class BaseTDEMProblem(Problem.BaseTimeProblem, BaseEMProblem):
if not isinstance(v, self.dataPair):
v = self.dataPair(self.survey, v)
# TODO: make this general
# if self._fieldType is 'b':
# dun_dmT_v = np.zeros((len(m), self.survey.nSrc))
PT_v = Fields_Derivs(self.mesh, self.survey) #PT_v is a fields object
# TODO: This will only work for b formulation right now b/c of the mesh.nF
df_duT_v = np.zeros((self.mesh.nF,self.nT+1))
ATinv_df_duT_v = np.zeros((self.mesh.nF,self.nT+1))
if ftype is 'bSolution' or 'jSolution':
df_duT_v = np.zeros((self.mesh.nF,self.nT+1))
ATinv_df_duT_v = np.zeros((self.mesh.nF,self.nT+1))
elif ftype is 'eSolution' or 'hSolution':
df_duT_v = np.zeros((self.mesh.nE,self.nT+1))
ATinv_df_duT_v = np.zeros((self.mesh.nE,self.nT+1))
JTv = np.zeros(m.shape)
# TODO : this is pretty ugly
# Loop over sources and receivers to create a fields object: PT_v
# Loop over sources and receivers to create a fields object: PT_v, df_duT_v, df_dmT_v
for src in self.survey.srcList:
# initialize empty fields derivs
for projField in set([rx.projField for rx in src.rxList]):
@@ -173,8 +181,6 @@ class BaseTDEMProblem(Problem.BaseTimeProblem, BaseEMProblem):
curPT_v = rx.evalDeriv(src, self.mesh, self.timeMesh, Utils.mkvc(v[src,rx]), adjoint=True)
PT_v[src,'%sDeriv'%rx.projField, :] += np.reshape(curPT_v,(len(curPT_v)/self.timeMesh.nN, self.timeMesh.nN), order='F') # All the fields for a given src, reciever.
# print np.linalg.norm(PT_v[src,'bDeriv',:])
# for src in self.survey.srcList:
# initialize empty fields derivs
for projField in set([rx.projField for rx in src.rxList]):
df_duTFun = getattr(u, '_%sDeriv'%projField, None)
@@ -184,19 +190,16 @@ class BaseTDEMProblem(Problem.BaseTimeProblem, BaseEMProblem):
JTv = JTv + df_dmT_v
df_duT_v = df_duT_v + df_duT_v_cur
# print np.linalg.norm(df_duT_v)
AdiagTinv = None
# for tInd in reversed(range(self.nT)): #enumerate(reversed(list(self.timeSteps))):
for tInd in reversed(range(self.nT)) : # reversed(self.timeSteps)):
if AdiagTinv is not None: # and (tInd <= self.nT and dt != self.timeSteps[tInd]):
# (tInd < self.nT and self.timeSteps[tInd] != self.timeSteps[tInd + 1]):# keep factors if dt is the same as previous step b/c A will be the same
# Do the back-solve through time
for tInd in reversed(range(self.nT)):
if AdiagTinv is not None and (tInd <= self.nT and self.timeSteps[tInd] != self.timeSteps[tInd+1]): # if the previous timestep is the same --> no need to refactor the matrix
AdiagTinv.clean()
AdiagTinv = None
# refactor if we need to
if AdiagTinv is None:
Adiag = self.getAdiag(tInd)
AdiagTinv = self.Solver(Adiag.T, **self.solverOpts)
@@ -208,73 +211,24 @@ class BaseTDEMProblem(Problem.BaseTimeProblem, BaseEMProblem):
Asubdiag = self.getAsubdiag(tInd+1)
ATinv_df_duT_v[:,tInd+1] = AdiagTinv * (df_duT_v[:,tInd+1] - Asubdiag.T * ATinv_df_duT_v[:,tInd+2])
# for src in self.survey.srcList:
un_src = u[src,ftype,tInd+1]
dAT_dm_v = self.getAdiagDeriv(None, un_src, ATinv_df_duT_v[:,tInd+1], adjoint=True) # cell centered on time mesh
for src in self.survey.srcList:
un_src = u[src,ftype,tInd+1]
dAT_dm_v = self.getAdiagDeriv(None, un_src, ATinv_df_duT_v[:,tInd+1], adjoint=True) # cell centered on time mesh
dRHST_dm_v = self.getRHSDeriv(tInd+1, src, ATinv_df_duT_v[:,tInd+1], adjoint=True) # on nodes of time mesh
# dAsubdiag_dm_v = 0
dRHST_dm_v = self.getRHSDeriv(tInd+1, src, ATinv_df_duT_v[:,tInd+1], adjoint=True) # on nodes of time mesh
# dAsubdiag_dm_v = 0
JTv = JTv + (-dAT_dm_v + dRHST_dm_v)
# JTv = JTv +
# tInd = 0
# un_src = u[src,ftype,tInd]
# # dAT_dm_v = self.getAdiagDeriv(None, un_src, self.getInitialFieldsDeriv(), adjoint=True)
# Asubdiag = self.getAsubdiag(tInd)
# ATinv_df_duT_v[:,tInd] = AdiagTinv * (- Asubdiag.T * df_duT_v[:,tInd])
# # - self.getAsubdiag(tInd).T * df_duT_v[:,tInd]
# dAT_dm_v = self.getAdiagDeriv(None, un_src, ATinv_df_duT_v[:,tInd], adjoint=True) # cell centered on time mesh
# dRHST_dm_v = self.getRHSDeriv(tInd, src, ATinv_df_duT_v[:,tInd], adjoint=True)
# JTv = JTv + (- dAT_dm_v + dRHST_dm_v)
JTv = JTv + (-dAT_dm_v + dRHST_dm_v)
# this doesn't include initial fields deriv
# adding du_dm^T * dF_du^T * P^T v for time 0 (no dRHS_dm_v at time 0)
JTv = JTv + self.getInitialFieldsDeriv(df_duT_v[:,0], adjoint=True)
return JTv
# for i, src in enumerate(self.survey.srcList):
# un_src = u[src,ftype,tInd+1] # fields for this source at tInd
# for rx in src.rxList:
# df_duTFun = getattr(u, '_%sDeriv'%rx.projField, None)
# df_duT_v, df_dmT_v = df_duTFun(tInd, src, None, PT_v[src,'%sDeriv'%rx.projField,tInd], adjoint=True)
# ATinv_df_duT_v = AdiagTinv * df_duT_v
# dAT_dm_v = self.getAdiagDeriv(tInd, un_src, ATinv_df_duT_v, adjoint=True)
# dRHST_dm_v = self.getRHSDeriv(tInd, src, ATinv_df_duT_v)
# # dAsubdiagT_dm_v = 0
# print dAT_dm_v.shape, Asubdiag.shape, ATinv_df_duT_v.shape, dun_dmT_v.shape,
# dun_dmT_v[:,i] = (dRHST_dm_v - dAT_dm_v - Asubdiag.T*dun_dmT_v[:,i])
# # rhsT_v = self.getJRHS(tInd, src, u_src, ATinv_df_duT_v, dun_dmT_v[:,i], adjoint = True)
# JTv = JTv + dun_dmT_v
# return Utils.mkvc(JTv)
# def getJRHS(self, tInd, src, u, v, adjoint = False):
# dA_dm = self.getADeriv(tInd, u, v, adjoint)
# dRHS_dm = self.getRHSDeriv(tInd, src, v, adjoint)
# b = - dA_dm + dRHS_dm
# return b
def getSourceTerm(self, tInd):
Srcs = self.survey.srcList
@@ -319,6 +273,9 @@ class BaseTDEMProblem(Problem.BaseTimeProblem, BaseEMProblem):
for i,src in enumerate(Srcs):
ifieldsDeriv[:,i] = ifieldsDeriv[:,i] + getattr(src, '%sInitialDeriv'%self._fieldType, None)(self,v,adjoint)
if adjoint is True:
ifieldsDeriv = ifieldsDeriv.sum()
return ifieldsDeriv
@@ -470,32 +427,6 @@ class Problem_b(BaseTDEMProblem):
return RHSDeriv
@Utils.timeIt
def getJdiags(self, tInd, adjoint = False):
    """
    Return the diagonal and sub-diagonal blocks used at time index tInd.

    The blocks belong to a block lower-bidiagonal system of the form:

        | Adiag                  | | uderiv1 |   | b1 |
        | Asub   Adiag           | | uderiv2 |   | b2 |
        |        Asub   Adiag    | | uderiv3 | = | b3 |
        |         ...    ...     | |   ...   |   | .. |
        |               Asub Adiag | uderivn |   | bn |

    :param tInd: index into self.timeSteps selecting the time step
    :param adjoint: the adjoint form is not implemented; a truthy value
        raises NotImplementedError
    :return: tuple (Adiag, Asub), where Adiag is self.getA(tInd) and Asub
        is -1/dt times an identity of size self.mesh.nF, left-multiplied
        by self.MfMui.T when self._makeASymmetric is set
    """
    if adjoint:
        raise NotImplementedError

    # Look up the step length first, then the diagonal block, in the same
    # order as before so any lookup error surfaces from the same call.
    stepLength = self.timeSteps[tInd]
    diagBlock = self.getA(tInd)

    subBlock = - 1./stepLength * Utils.speye(self.mesh.nF)
    if self._makeASymmetric:
        # apply the same symmetrizing left-multiplication to the off-diagonal block
        subBlock = self.MfMui.T * subBlock

    return diagBlock, subBlock