From 9acaae66db2e6264a82ed01f2d288663752a852e Mon Sep 17 00:00:00 2001 From: rowanc1 Date: Sat, 22 Mar 2014 13:22:44 -0700 Subject: [PATCH] Projections and derivatives working. --- simpegEM/FDEM/FDEM.py | 47 ++++++++------------ simpegEM/FDEM/SurveyFDEM.py | 69 +++++++++++++---------------- simpegEM/Tests/test_FDEM.py | 5 ++- simpegEM/Tests/test_FieldsObject.py | 7 ++- 4 files changed, 55 insertions(+), 73 deletions(-) diff --git a/simpegEM/FDEM/FDEM.py b/simpegEM/FDEM/FDEM.py index 6df6b03a..ab855f7c 100644 --- a/simpegEM/FDEM/FDEM.py +++ b/simpegEM/FDEM/FDEM.py @@ -110,15 +110,15 @@ class BaseProblemFDEM(Problem.BaseProblem): u_tx = u[tx, self.solType] w = self.getADeriv(freq, u_tx, v) Ainvw = solver.solve(w) - fAinvw = self._calcFieldsList(Ainvw, freq, tx.rxList.fieldTypes) - P = lambda v: tx.projectFieldsDeriv(self.mesh, u, v) + for rx in tx.rxList: + fAinvw = self.calcFields(Ainvw, freq, rx.projField) + P = lambda v: rx.projectFieldsDeriv(tx, self.mesh, u, v) - df_dm = self._calcFieldsDerivList(u_tx, freq, tx.rxList.fieldTypes, v) - #TODO: this is now a list? - if df_dm is None: - Jv[tx] = - P(fAinvw) - else: - Jv[tx] = - P(fAinvw) + P(df_dm) + df_dm = self.calcFieldsDeriv(u_tx, freq, rx.projField, v) + if df_dm is None: + Jv[tx, rx] = - P(fAinvw) + else: + Jv[tx, rx] = - P(fAinvw) + P(df_dm) return Utils.mkvc(Jv) @@ -141,33 +141,22 @@ class BaseProblemFDEM(Problem.BaseProblem): for tx in self.survey.getTransmitters(freq): u_tx = u[tx, self.solType] - PTv = tx.projectFieldsDeriv(self.mesh, u, v[tx], adjoint=True) - fPTv = self._calcFieldsList(PTv, freq, tx.rxList.fieldTypes, adjoint=True) + for rx in tx.rxList: + PTv = rx.projectFieldsDeriv(tx, self.mesh, u, v[tx, rx], adjoint=True) + fPTv = self.calcFields(PTv, freq, rx.projField, adjoint=True) - w = solver.solve( fPTv ) - Jtv_tx = self.getADeriv(freq, u_tx, w, adjoint=True) + w = solver.solve( fPTv ) + Jtv_tx = self.getADeriv(freq, u_tx, w, adjoint=True) - df_dm = self._calcFieldsDerivList(u_tx, freq, tx.rxList.fieldTypes, PTv, adjoint=True) + df_dm = self.calcFieldsDeriv(u_tx, freq, rx.projField, PTv, adjoint=True) - if df_dm is None: - Jtv += - Jtv_tx - else: - Jtv += - Jtv_tx + df_dm + if df_dm is None: + Jtv += - Jtv_tx + else: + Jtv += - Jtv_tx + df_dm return Jtv - def _calcFieldsList(self, sol, freq, fieldTypes, adjoint=False): - fVecs = range(len(fieldTypes)) - for ii, fieldType in enumerate(fieldTypes): - fVecs[ii] = self.calcFields(sol, freq, fieldType, adjoint=adjoint) - return np.concatenate(fVecs) - - def _calcFieldsDerivList(self, sol, freq, fieldTypes, v, adjoint=False): - fVecs = range(len(fieldTypes)) - V = v.reshape((-1, len(fieldTypes)), order='F') - for ii, fieldType in enumerate(fieldTypes): - fVecs[ii] = self.calcFieldsDeriv(sol, freq, fieldType, V[:,ii], adjoint=adjoint) - return np.concatenate(fVecs) class ProblemFDEM_e(BaseProblemFDEM): """ diff --git a/simpegEM/FDEM/SurveyFDEM.py b/simpegEM/FDEM/SurveyFDEM.py index 0be03e84..834f7d2b 100644 --- a/simpegEM/FDEM/SurveyFDEM.py +++ b/simpegEM/FDEM/SurveyFDEM.py @@ -61,6 +61,35 @@ class RxFDEM(Survey.BaseRx): P = self._Ps[gloc][mesh] return P + def projectFields(self, tx, mesh, u): + P = self.getP(mesh) + u_part_complex = u[tx, self.projField] + # get the real or imag component + real_or_imag = self.projComp + u_part = getattr(u_part_complex, real_or_imag) + return P*u_part + + def projectFieldsDeriv(self, tx, mesh, u, v, adjoint=False): + P = self.getP(mesh) + + if not adjoint: + Pv_complex = P * v + #TODO: check this deriv... + real_or_imag = self.projComp + Pv = getattr(Pv_complex, real_or_imag) + elif adjoint: + Pv_real = P.T * v + + real_or_imag = self.projComp + if real_or_imag == 'imag': + Pv = 1j*Pv_real + elif real_or_imag == 'real': + Pv = Pv_real.astype(complex) + else: + raise NotImplementedError('must be real or imag') + + return Pv + class TxFDEM(Survey.BaseTx): @@ -84,41 +113,6 @@ class TxFDEM(Survey.BaseTx): """Vector number of data""" return np.array([rx.nD for rx in self.rxList]) - def projectFields(self, mesh, u): - - nRt = len(self.rxList) - Pu = range(nRt) - - for ii, rx in enumerate(self.rxList): - P = rx.getP(mesh) - u_part_complex = u[self, rx.projField] - # get the real or imag component - real_or_imag = rx.projComp - u_part = getattr(u_part_complex, real_or_imag) - Pu[ii] = P*u_part - return np.concatenate(Pu) - - def projectFieldsDeriv(self, mesh, u, v, adjoint=False): - V = v.reshape((-1, len(Ps)), order='F') - Pvs = range(len(Ps)) - for ii, rx in enumerate(self.rxList): - P = rx.getP(mesh) - - if not adjoint: - Pv = Ps[ii] * V[:, ii] - elif adjoint: - Pv = Ps[ii].T * V[:, ii] - - real_or_imag = rx.projComp - if real_or_imag == 'imag': - Pvs[ii] = 1j*Pv - elif real_or_imag == 'real': - Pvs[ii] = Pv.astype(complex) - else: - raise NotImplementedError('must be real or imag') - - return np.concatenate(Pvs) - class FieldsFDEM(object): """Fancy Field Storage for a FDEM survey.""" @@ -285,8 +279,6 @@ class DataFDEM(object): indBot += rx.nD - - class SurveyFDEM(Survey.BaseSurvey): """ docstring for SurveyFDEM @@ -349,7 +341,8 @@ class SurveyFDEM(Survey.BaseSurvey): def projectFields(self, u): data = DataFDEM(self) for tx in self.txList: - data[tx] = tx.projectFields(self.mesh, u) + for rx in tx.rxList: + data[tx, rx] = rx.projectFields(tx, self.mesh, u) return data def projectFieldsDeriv(self, u): diff --git a/simpegEM/Tests/test_FDEM.py b/simpegEM/Tests/test_FDEM.py index c60446c9..4b140823 100644 --- a/simpegEM/Tests/test_FDEM.py +++ b/simpegEM/Tests/test_FDEM.py @@ -70,8 +70,9 @@ class FDEM_DerivTests_e(unittest.TestCase): m = self.sigma u = self.prb.fields(m) - vJw = v.dot(self.prb.Jvec(m, w, u=u)) - wJtv = w.dot(self.prb.Jtvec(m, v, u=u)) + vJw = np.vdot(v, self.prb.Jvec(m, w, u=u)) + wJtv = np.vdot(w, self.prb.Jtvec(m, v, u=u)) + print 'Jtvec: ', vJw - wJtv self.assertTrue(vJw - wJtv < TOL) diff --git a/simpegEM/Tests/test_FieldsObject.py b/simpegEM/Tests/test_FieldsObject.py index 64190383..e4f60911 100644 --- a/simpegEM/Tests/test_FieldsObject.py +++ b/simpegEM/Tests/test_FieldsObject.py @@ -100,17 +100,16 @@ class FieldsTest(unittest.TestCase): Txs = F.survey.getTransmitters(freq) for ii, tx in enumerate(Txs): - dat = tx.projectFields(self.mesh, F) - self.assertTrue(dat.dtype == float) - dat = dat.reshape((self.XYZ.shape[0], len(tx.rxList)), order='F') for jj, rx in enumerate(tx.rxList): + dat = rx.projectFields(tx, self.mesh, F) + self.assertTrue(dat.dtype == float) fieldType = rx.projField u = {'b':b[:,ii], 'e': e[:,ii]}[fieldType] real_or_imag = rx.projComp u = getattr(u, real_or_imag) gloc = rx.projGLoc d = self.mesh.getInterpolationMat(self.XYZ, gloc)*u - self.assertTrue(np.all(dat[:, jj] == d)) + self.assertTrue(np.all(dat == d))