diff --git a/simpegEM/FDEM/FDEM.py b/simpegEM/FDEM/FDEM.py index 83fc9ee5..75e880ac 100644 --- a/simpegEM/FDEM/FDEM.py +++ b/simpegEM/FDEM/FDEM.py @@ -44,16 +44,16 @@ class BaseFDEMProblem(BaseEMProblem): for freq in self.survey.freqs: A = self.getA(freq) - dF_duI = self.Solver(A, **self.solverOpts) + Ainv = self.Solver(A, **self.solverOpts) for src in self.survey.getSrcByFreq(freq): u_src = u[src, self._fieldType] dF_dm = self.getADeriv(freq, u_src, v) dRHS_dm = self.getRHSDeriv(src, v) if dRHS_dm is None: - du_dm = dF_duI * ( - dF_dm ) + du_dm = Ainv * ( - dF_dm ) else: - du_dm = dF_duI * ( - dF_dm + dRHS_dm ) + du_dm = Ainv * ( - dF_dm + dRHS_dm ) for rx in src.rxList: dAl_duFun = getattr(u, '_%sDeriv_u'%rx.projField, None) dAl_du = dAl_duFun(src, du_dm, adjoint=False) @@ -69,15 +69,6 @@ class BaseFDEMProblem(BaseEMProblem): Jv[src, rx] = P(du_dm) - # fAinvw = self.calcFields(Ainvw, freq, rx.projField) - # P = lambda v: rx.projectFieldsDeriv(src, self.mesh, u, v) - - # Jv[src, rx] = - P(fAinvw) - - # df_dm = self.calcFieldsDeriv(u_src, freq, rx.projField, v) - # if df_dm is not None: - # Jv[src, rx] += P(df_dm) - return Utils.mkvc(Jv) def Jtvec(self, m, v, u=None): @@ -90,32 +81,56 @@ class BaseFDEMProblem(BaseEMProblem): if not isinstance(v, self.dataPair): v = self.dataPair(self.survey, v) - Jtv = np.zeros(self.mapping.nP) + # Jtv = np.zeros(self.PropMap.PropModel.nP) + Jtv = np.zeros(m.size) for freq in self.survey.freqs: AT = self.getA(freq).T ATinv = self.Solver(AT, **self.solverOpts) - for src in self.survey.getSource(freq): - u_src = u[src, self.solType] + for src in self.survey.getSrcByFreq(freq): + u_src = u[src, self._fieldType] for rx in src.rxList: PTv = rx.projectFieldsDeriv(src, self.mesh, u, v[src, rx], adjoint=True) - fPTv = self.calcFields(PTv, freq, rx.projField, adjoint=True) - w = ATinv * fPTv - Jtv_rx = - self.getADeriv(freq, u_src, w, adjoint=True) + dAl_duTFun = getattr(u, '_%sDeriv_u'%rx.projField, None) + dAl_duT = dAl_duTFun(src, PTv, adjoint=True) + if dAl_duT is not None: + dF_duIT = ATinv * dAl_duT + else: + dF_duIT = ATinv * PTv - df_dm = self.calcFieldsDeriv(u_src, freq, rx.projField, PTv, adjoint=True) + dF_dmT = self.getADeriv(freq, u_src, dF_duIT, adjoint=True) - if df_dm is not None: - Jtv_rx += df_dm + dRHS_dmT = self.getRHSDeriv(src, dF_duIT, adjoint=True) + + if dRHS_dmT is None: + du_dmT = - dF_dmT + else: + du_dmT = -dF_dmT + dRHS_dmT + + dAl_dmFun = getattr(u, '_%sDeriv_m'%rx.projField, None) + dAlT_dm = dAl_dmFun(src, PTv, adjoint=True) + if dAlT_dm is not None: + du_dmT += dAlT_dm + + + # fPTv = self.calcFields(PTv, freq, rx.projField, adjoint=True) + + # w = ATinv * fPTv + # Jtv_rx = - self.getADeriv(freq, u_src, w, adjoint=True) + + # df_dm = self.calcFieldsDeriv(u_src, freq, rx.projField, PTv, adjoint=True) + + # if df_dm is not None: + # Jtv_rx += df_dm real_or_imag = rx.projComp if real_or_imag == 'real': - Jtv += Jtv_rx.real + Jtv += du_dmT.real elif real_or_imag == 'imag': - Jtv += - Jtv_rx.real + Jtv += - du_dmT.real else: raise Exception('Must be real or imag') @@ -220,19 +235,35 @@ class ProblemFDEM_e(BaseFDEMProblem): return RHS def getRHSDeriv(self, src, v, adjoint=False): - S_mDeriv, S_eDeriv = src.evalDeriv(self, v, adjoint) - if adjoint: - # evalDeriv(MfMui.T* C * v, adjoint = True) - raise Exception('Not implemented') + C = self.mesh.edgeCurl + MfMui = self.MfMui + S_mDeriv, S_eDeriv = src.evalDeriv(self, adjoint) + # # evalDeriv(MfMui.T* (C * v), adjoint) + # raise Exception('Not implemented') - if S_mDeriv is not None and S_eDeriv is not None: - return C.T * (MfMui * S_mDeriv) -1j*omega(freq)*S_eDeriv - elif S_mDeriv is not None: - return C.T * (MfMui * S_mDeriv) - elif S_eDeriv is not None: - return -1j*omega(freq)*S_eDeriv - else: - return None + if adjoint: + dRHS = MfMui * (C * v) + S_mDerivv = S_mDeriv(dRHS) + S_eDerivv = S_eDeriv(v) + if S_mDerivv is not None and S_eDerivv is not None: + return S_mDerivv - 1j*omega(freq)*S_eDerivv + elif S_mDerivv is not None: + return S_mDerivv + elif S_eDerivv is not None: + return - 1j*omega(freq)*S_eDerivv + else: + return None + else: + S_mDerivv, S_eDerivv = S_mDeriv(v), S_eDeriv(v) + + if S_mDerivv is not None and S_eDerivv is not None: + return C.T * (MfMui * S_mDerivv) -1j*omega(freq)*S_eDerivv + elif S_mDerivv is not None: + return C.T * (MfMui * S_mDerivv) + elif S_eDerivv is not None: + return -1j*omega(freq)*S_eDerivv + else: + return None class ProblemFDEM_b(BaseFDEMProblem): diff --git a/simpegEM/FDEM/FieldsFDEM.py b/simpegEM/FDEM/FieldsFDEM.py index 8844b3ce..c98edba8 100644 --- a/simpegEM/FDEM/FieldsFDEM.py +++ b/simpegEM/FDEM/FieldsFDEM.py @@ -70,7 +70,8 @@ class FieldsFDEM_e(FieldsFDEM): return - 1./(1j*omega(src.freq)) * (C * v) def _bSecondaryDeriv_m(self, src, v, adjoint = False): - S_mDeriv, _ = src.evalDeriv(self.survey.prob, v, adjoint) + S_mDeriv, _ = src.evalDeriv(self.survey.prob, adjoint) + S_mDeriv = S_mDeriv(v) if S_mDeriv is not None: return 1./(1j * omega(src.freq)) * S_mDeriv return None @@ -142,7 +143,8 @@ class FieldsFDEM_b(FieldsFDEM): def _eDeriv(self, bSolution, srcList, v, adjoint=False): raise NotImplementedError('Fields Derivs Not Implemented Yet') - _,S_eDeriv = src.evalDeriv(self.survey.prob, v, adjoint) + _,S_eDeriv = src.evalDeriv(self.survey.prob, adjoint) + S_eDeriv = S_eDeriv(v) if S_eDeriv is None: return None diff --git a/simpegEM/FDEM/SurveyFDEM.py b/simpegEM/FDEM/SurveyFDEM.py index 84a278c1..47749ef0 100644 --- a/simpegEM/FDEM/SurveyFDEM.py +++ b/simpegEM/FDEM/SurveyFDEM.py @@ -101,7 +101,7 @@ class SrcFDEM(Survey.BaseSrc): return S_m, S_e def evalDeriv(self, prob, v, adjoint=False): - return self.S_mDeriv(prob,v,adjoint), self.S_eDeriv(prob,v,adjoint) + return lambda v: self.S_mDeriv(prob,v,adjoint), lambda v: self.S_eDeriv(prob,v,adjoint) def bPrimary(self,prob): return None diff --git a/simpegEM/Tests/test_FDEM.py b/simpegEM/Tests/test_FDEM.py index fe9e69ed..b9121b6f 100644 --- a/simpegEM/Tests/test_FDEM.py +++ b/simpegEM/Tests/test_FDEM.py @@ -7,7 +7,7 @@ import copy testDerivs = True testCrossCheck = False -testAdjoint = False +testAdjoint = True testEB = True testHJ = False @@ -81,7 +81,8 @@ def adjointTest(fdemType, comp): u = prb.fields(m) v = np.random.rand(survey.nD) - w = np.random.rand(prb.mapping.nP) + # print prb.PropMap.PropModel.nP + w = np.random.rand(prb.mesh.nC) vJw = v.dot(prb.Jvec(m, w, u=u)) wJtv = w.dot(prb.Jtvec(m, v, u=u)) @@ -271,53 +272,53 @@ class FDEM_DerivTests(unittest.TestCase): if testEB: def test_Jtvec_adjointTest_exr_Eform(self): self.assertTrue(adjointTest('e', 'exr')) - def test_Jtvec_adjointTest_exr_Bform(self): - self.assertTrue(adjointTest('b', 'exr')) + # def test_Jtvec_adjointTest_exr_Bform(self): + # self.assertTrue(adjointTest('b', 'exr')) def test_Jtvec_adjointTest_eyr_Eform(self): self.assertTrue(adjointTest('e', 'eyr')) - def test_Jtvec_adjointTest_eyr_Bform(self): - self.assertTrue(adjointTest('b', 'eyr')) + # def test_Jtvec_adjointTest_eyr_Bform(self): + # self.assertTrue(adjointTest('b', 'eyr')) def test_Jtvec_adjointTest_ezr_Eform(self): self.assertTrue(adjointTest('e', 'ezr')) - def test_Jtvec_adjointTest_ezr_Bform(self): - self.assertTrue(adjointTest('b', 'ezr')) + # def test_Jtvec_adjointTest_ezr_Bform(self): + # self.assertTrue(adjointTest('b', 'ezr')) def test_Jtvec_adjointTest_exi_Eform(self): self.assertTrue(adjointTest('e', 'exi')) - def test_Jtvec_adjointTest_exi_Bform(self): - self.assertTrue(adjointTest('b', 'exi')) + # def test_Jtvec_adjointTest_exi_Bform(self): + # self.assertTrue(adjointTest('b', 'exi')) def test_Jtvec_adjointTest_eyi_Eform(self): self.assertTrue(adjointTest('e', 'eyi')) - def test_Jtvec_adjointTest_eyi_Bform(self): - self.assertTrue(adjointTest('b', 'eyi')) + # def test_Jtvec_adjointTest_eyi_Bform(self): + # self.assertTrue(adjointTest('b', 'eyi')) def test_Jtvec_adjointTest_ezi_Eform(self): self.assertTrue(adjointTest('e', 'ezi')) - def test_Jtvec_adjointTest_ezi_Bform(self): - self.assertTrue(adjointTest('b', 'ezi')) + # def test_Jtvec_adjointTest_ezi_Bform(self): + # self.assertTrue(adjointTest('b', 'ezi')) def test_Jtvec_adjointTest_bxr_Eform(self): self.assertTrue(adjointTest('e', 'bxr')) - def test_Jtvec_adjointTest_bxr_Bform(self): - self.assertTrue(adjointTest('b', 'bxr')) + # def test_Jtvec_adjointTest_bxr_Bform(self): + # self.assertTrue(adjointTest('b', 'bxr')) def test_Jtvec_adjointTest_byr_Eform(self): self.assertTrue(adjointTest('e', 'byr')) - def test_Jtvec_adjointTest_byr_Bform(self): - self.assertTrue(adjointTest('b', 'byr')) + # def test_Jtvec_adjointTest_byr_Bform(self): + # self.assertTrue(adjointTest('b', 'byr')) def test_Jtvec_adjointTest_bzr_Eform(self): self.assertTrue(adjointTest('e', 'bzr')) - def test_Jtvec_adjointTest_bzr_Bform(self): - self.assertTrue(adjointTest('b', 'bzr')) + # def test_Jtvec_adjointTest_bzr_Bform(self): + # self.assertTrue(adjointTest('b', 'bzr')) def test_Jtvec_adjointTest_bxi_Eform(self): self.assertTrue(adjointTest('e', 'bxi')) - def test_Jtvec_adjointTest_bxi_Bform(self): - self.assertTrue(adjointTest('b', 'bxi')) + # def test_Jtvec_adjointTest_bxi_Bform(self): + # self.assertTrue(adjointTest('b', 'bxi')) def test_Jtvec_adjointTest_byi_Eform(self): self.assertTrue(adjointTest('e', 'byi')) - def test_Jtvec_adjointTest_byi_Bform(self): - self.assertTrue(adjointTest('b', 'byi')) + # def test_Jtvec_adjointTest_byi_Bform(self): + # self.assertTrue(adjointTest('b', 'byi')) def test_Jtvec_adjointTest_bzi_Eform(self): self.assertTrue(adjointTest('e', 'bzi')) - def test_Jtvec_adjointTest_bzi_Bform(self): - self.assertTrue(adjointTest('b', 'bzi')) + # def test_Jtvec_adjointTest_bzi_Bform(self): + # self.assertTrue(adjointTest('b', 'bzi')) if testHJ: def test_Jtvec_adjointTest_jxr_Jform(self):