mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-30 22:03:56 +08:00
adjoint of fields deriv now returns a tuple (so you don't call derivs wrt _u, _m independently
This commit is contained in:
+8
-22
@@ -91,18 +91,7 @@ class BaseFDEMProblem(BaseEMProblem):
|
||||
for rx in src.rxList:
|
||||
df_dmFun = getattr(u, '_%sDeriv'%rx.projField, None)
|
||||
df_dm = df_dmFun(src, du_dm, v, adjoint=False)
|
||||
|
||||
# df_duFun = getattr(u, '_%sDeriv_u'%rx.projField, None)
|
||||
# df_dudu_dm = df_duFun(src, du_dm, adjoint=False)
|
||||
|
||||
# df_dmFun = getattr(u, '_%sDeriv_m'%rx.projField, None)
|
||||
# df_dm = df_dmFun(src, v, adjoint=False)
|
||||
|
||||
|
||||
Df_Dm = np.array(df_dm,dtype=complex)
|
||||
|
||||
# P = lambda v:
|
||||
|
||||
Jv[src, rx] = rx.projectFieldsDeriv(src, self.mesh, u, Df_Dm)
|
||||
|
||||
Ainv.clean()
|
||||
@@ -141,26 +130,23 @@ class BaseFDEMProblem(BaseEMProblem):
|
||||
for rx in src.rxList:
|
||||
PTv = rx.projectFieldsDeriv(src, self.mesh, u, v[src, rx], adjoint=True) # wrt u, need possibility wrt m
|
||||
|
||||
df_duTFun = getattr(u, '_%sDeriv_u'%rx.projField, None)
|
||||
df_duT = df_duTFun(src, PTv, adjoint=True)
|
||||
|
||||
df_duTFun = getattr(u, '_%sDeriv'%rx.projField, None)
|
||||
df_duT, df_dmT = df_duTFun(src, None, PTv, adjoint=True)
|
||||
|
||||
ATinvdf_duT = ATinv * df_duT
|
||||
|
||||
dA_dmT = self.getADeriv_m(freq, u_src, ATinvdf_duT, adjoint=True)
|
||||
dRHS_dmT = self.getRHSDeriv_m(freq,src, ATinvdf_duT, adjoint=True)
|
||||
dRHS_dmT = self.getRHSDeriv_m(freq, src, ATinvdf_duT, adjoint=True)
|
||||
du_dmT = -dA_dmT + dRHS_dmT
|
||||
|
||||
df_dmFun = getattr(u, '_%sDeriv_m'%rx.projField, None)
|
||||
dfT_dm = df_dmFun(src, PTv, adjoint=True)
|
||||
df_dmT += du_dmT
|
||||
|
||||
du_dmT += dfT_dm
|
||||
|
||||
# TODO: this should be taken care of by the reciever
|
||||
# TODO: this should be taken care of by the reciever?
|
||||
real_or_imag = rx.projComp
|
||||
if real_or_imag is 'real':
|
||||
Jtv += np.array(du_dmT,dtype=complex).real
|
||||
Jtv += np.array(df_dmT,dtype=complex).real
|
||||
elif real_or_imag is 'imag':
|
||||
Jtv += - np.array(du_dmT,dtype=complex).real
|
||||
Jtv += - np.array(df_dmT,dtype=complex).real
|
||||
else:
|
||||
raise Exception('Must be real or imag')
|
||||
|
||||
|
||||
@@ -89,26 +89,75 @@ class Fields(SimPEG.Problem.Fields):
|
||||
return self._jPrimary(solution, srcList) + self._jSecondary(solution, srcList)
|
||||
|
||||
def _eDeriv(self, src, du_dm, v, adjoint = False):
|
||||
"""
|
||||
Total derivative of e with respect to the inversion model. Returns :math:`d\mathbf{e}/d\mathbf{m}` for forward and (:math:`d\mathbf{e}/d\mathbf{u}`, :math:`d\mathb{u}/d\mathbf{m}`) for the adjoint
|
||||
|
||||
:param Src src: sorce
|
||||
:param numpy.ndarray du_dm: derivative of the solution vector with respect to the model times a vector (is None for adjoint)
|
||||
:param numpy.ndarray v: vector to take sensitivity product with
|
||||
:param bool adjoint: adjoint?
|
||||
:rtype: numpy.ndarray
|
||||
:return: derivative times a vector (or tuple for adjoint)
|
||||
"""
|
||||
if getattr(self, '_eDeriv_u', None) is None or getattr(self, '_eDeriv_m', None) is None:
|
||||
raise NotImplementedError ('Getting eDerivs from %s is not implemented' %self.knownFields.keys()[0])
|
||||
|
||||
if adjoint:
|
||||
return self._eDeriv_u(src, v, adjoint), self._eDeriv_m(src, v, adjoint)
|
||||
return self._eDeriv_u(src, du_dm, adjoint) + self._eDeriv_m(src, v, adjoint)
|
||||
|
||||
def _bDeriv(self, src, du_dm, v, adjoint = False):
|
||||
"""
|
||||
Total derivative of b with respect to the inversion model. Returns :math:`d\mathbf{b}/d\mathbf{m}` for forward and (:math:`d\mathbf{b}/d\mathbf{u}`, :math:`d\mathb{u}/d\mathbf{m}`) for the adjoint
|
||||
|
||||
:param Src src: sorce
|
||||
:param numpy.ndarray du_dm: derivative of the solution vector with respect to the model times a vector (is None for adjoint)
|
||||
:param numpy.ndarray v: vector to take sensitivity product with
|
||||
:param bool adjoint: adjoint?
|
||||
:rtype: numpy.ndarray
|
||||
:return: derivative times a vector (or tuple for adjoint)
|
||||
"""
|
||||
if getattr(self, '_bDeriv_u', None) is None or getattr(self, '_bDeriv_m', None) is None:
|
||||
raise NotImplementedError ('Getting bDerivs from %s is not implemented' %self.knownFields.keys()[0])
|
||||
|
||||
if adjoint:
|
||||
return self._bDeriv_u(src, v, adjoint), self._bDeriv_m(src, v, adjoint)
|
||||
return self._bDeriv_u(src, du_dm, adjoint) + self._bDeriv_m(src, v, adjoint)
|
||||
|
||||
def _hDeriv(self, src, du_dm, v, adjoint = False):
|
||||
"""
|
||||
Total derivative of h with respect to the inversion model. Returns :math:`d\mathbf{h}/d\mathbf{m}` for forward and (:math:`d\mathbf{h}/d\mathbf{u}`, :math:`d\mathb{u}/d\mathbf{m}`) for the adjoint
|
||||
|
||||
:param Src src: sorce
|
||||
:param numpy.ndarray du_dm: derivative of the solution vector with respect to the model times a vector (is None for adjoint)
|
||||
:param numpy.ndarray v: vector to take sensitivity product with
|
||||
:param bool adjoint: adjoint?
|
||||
:rtype: numpy.ndarray
|
||||
:return: derivative times a vector (or tuple for adjoint)
|
||||
"""
|
||||
if getattr(self, '_hDeriv_u', None) is None or getattr(self, '_hDeriv_m', None) is None:
|
||||
raise NotImplementedError ('Getting hDerivs from %s is not implemented' %self.knownFields.keys()[0])
|
||||
|
||||
if adjoint:
|
||||
return self._hDeriv_u(src, v, adjoint), self._hDeriv_m(src, v, adjoint)
|
||||
return self._hDeriv_u(src, du_dm, adjoint) + self._hDeriv_m(src, v, adjoint)
|
||||
|
||||
def _jDeriv(self, src, du_dm, v, adjoint = False):
|
||||
"""
|
||||
Total derivative of j with respect to the inversion model. Returns :math:`d\mathbf{j}/d\mathbf{m}` for forward and (:math:`d\mathbf{j}/d\mathbf{u}`, :math:`d\mathb{u}/d\mathbf{m}`) for the adjoint
|
||||
|
||||
:param Src src: sorce
|
||||
:param numpy.ndarray du_dm: derivative of the solution vector with respect to the model times a vector (is None for adjoint)
|
||||
:param numpy.ndarray v: vector to take sensitivity product with
|
||||
:param bool adjoint: adjoint?
|
||||
:rtype: numpy.ndarray
|
||||
:return: derivative times a vector (or tuple for adjoint)
|
||||
"""
|
||||
if getattr(self, '_jDeriv_u', None) is None or getattr(self, '_jDeriv_m', None) is None:
|
||||
raise NotImplementedError ('Getting jDerivs from %s is not implemented' %self.knownFields.keys()[0])
|
||||
|
||||
if adjoint:
|
||||
return self._jDeriv_u(src, v, adjoint), self._jDeriv_m(src, v, adjoint)
|
||||
return self._jDeriv_u(src, du_dm, adjoint) + self._jDeriv_m(src, v, adjoint)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user