diff --git a/simpegEM/FDEM/FDEM.py b/simpegEM/FDEM/FDEM.py index 9c5a6643..6df6b03a 100644 --- a/simpegEM/FDEM/FDEM.py +++ b/simpegEM/FDEM/FDEM.py @@ -110,10 +110,10 @@ class BaseProblemFDEM(Problem.BaseProblem): u_tx = u[tx, self.solType] w = self.getADeriv(freq, u_tx, v) Ainvw = solver.solve(w) - fAinvw = self.calcFieldsList(Ainvw, freq, tx.rxList.fieldTypes) + fAinvw = self._calcFieldsList(Ainvw, freq, tx.rxList.fieldTypes) P = lambda v: tx.projectFieldsDeriv(self.mesh, u, v) - df_dm = self.calcFieldsDerivList(u_tx, freq, tx.rxList.fieldTypes, v) + df_dm = self._calcFieldsDerivList(u_tx, freq, tx.rxList.fieldTypes, v) #TODO: this is now a list? if df_dm is None: Jv[tx] = - P(fAinvw) @@ -142,12 +142,12 @@ class BaseProblemFDEM(Problem.BaseProblem): u_tx = u[tx, self.solType] PTv = tx.projectFieldsDeriv(self.mesh, u, v[tx], adjoint=True) - fPTv = self.calcFields(PTv, freq, tx.rxList.fieldTypes, adjoint=True) + fPTv = self._calcFieldsList(PTv, freq, tx.rxList.fieldTypes, adjoint=True) w = solver.solve( fPTv ) Jtv_tx = self.getADeriv(freq, u_tx, w, adjoint=True) - df_dm = self.calcFieldsDeriv(u_tx, freq, tx.rxList.fieldTypes, PTv, adjoint=True) + df_dm = self._calcFieldsDerivList(u_tx, freq, tx.rxList.fieldTypes, PTv, adjoint=True) if df_dm is None: Jtv += - Jtv_tx @@ -156,6 +156,19 @@ class BaseProblemFDEM(Problem.BaseProblem): return Jtv + def _calcFieldsList(self, sol, freq, fieldTypes, adjoint=False): + fVecs = range(len(fieldTypes)) + for ii, fieldType in enumerate(fieldTypes): + fVecs[ii] = self.calcFields(sol, freq, fieldType, adjoint=adjoint) + return np.concatenate(fVecs) + + def _calcFieldsDerivList(self, sol, freq, fieldTypes, v, adjoint=False): + fVecs = range(len(fieldTypes)) + V = v.reshape((-1, len(fieldTypes)), order='F') + for ii, fieldType in enumerate(fieldTypes): + fVecs[ii] = self.calcFieldsDeriv(sol, freq, fieldType, V[:,ii], adjoint=adjoint) + return np.concatenate(fVecs) + class ProblemFDEM_e(BaseProblemFDEM): """ Solving for e! diff --git a/simpegEM/FDEM/SurveyFDEM.py b/simpegEM/FDEM/SurveyFDEM.py index e68d50fa..74a3ecaf 100644 --- a/simpegEM/FDEM/SurveyFDEM.py +++ b/simpegEM/FDEM/SurveyFDEM.py @@ -1,6 +1,6 @@ from SimPEG import Survey, Utils, np, sp -class RxListFDEM(Survey.BaseRxList): +class RxFDEM(Survey.BaseRx): knownRxTypes = { 'exr':['e', 'Ex', 'real'], @@ -19,36 +19,25 @@ class RxListFDEM(Survey.BaseRxList): } def __init__(self, locs, rxType): - Survey.BaseRxList.__init__(self, locs, rxType) + Survey.BaseRx.__init__(self, locs, rxType) self._Ps = {} - for rx in self.rxTypes: - self._Ps[self._projGLoc(rx)] = {} + self._Ps[self.projGLoc] = {} - def _projField(self, rx): + @property + def projField(self): """Field Type projection (e.g. e b ...)""" - if type(rx) is int: rx = self.rxTypes[rx] - return self.knownRxTypes[rx][0] + return self.knownRxTypes[self.rxType][0] - def _projGLoc(self, rx): + @property + def projGLoc(self): """Grid Location projection (e.g. Ex Fy ...)""" - if type(rx) is int: rx = self.rxTypes[rx] - return self.knownRxTypes[rx][1] + return self.knownRxTypes[self.rxType][1] - def _projComp(self, rx): + @property + def projComp(self): """Component projection (real/imag)""" - if type(rx) is int: rx = self.rxTypes[rx] - return self.knownRxTypes[rx][2] - - @property - def rxTypes(self): - """A list of the recieve types that are collected at this rxList locations.""" - return self.rxType.split(',') - - @property - def fieldTypes(self): - #TODO: This assumes that it has the structure ex, by ... - return [self._projField(rx) for rx in self.rxTypes] + return self.knownRxTypes[self.rxType][2] def getP(self, mesh): """ @@ -61,12 +50,10 @@ class RxListFDEM(Survey.BaseRxList): Projection matrices are stored as a nested dict, First gridLocation, then mesh. """ - P = [] - for rx in self.rxTypes: - gloc = self._projGLoc(rx) - if mesh not in self._Ps[gloc]: - self._Ps[gloc][mesh] = mesh.getInterpolationMat(self.locs, gloc) - P += [self._Ps[gloc][mesh]] + gloc = self.projGLoc + if mesh not in self._Ps[gloc]: + self._Ps[gloc][mesh] = mesh.getInterpolationMat(self.locs, gloc) + P = self._Ps[gloc][mesh] return P @@ -74,7 +61,7 @@ class TxFDEM(Survey.BaseTx): freq = None #: Frequency (float) - rxListPair = RxListFDEM + rxListPair = RxFDEM knownTxTypes = ['VMD'] @@ -85,35 +72,43 @@ class TxFDEM(Survey.BaseTx): @property def nD(self): """Number of data""" - return self.rxList.locs.shape[0]*len(self.rxList.rxTypes) + return self.vnD.sum() + + @property + def vnD(self): + """Vector number of data""" + return np.array([rx.locs.shape[0] for rx in self.rxList]) def projectFields(self, mesh, u): - nRt = len(self.rxList.rxTypes) + nRt = len(self.rxList) Pu = range(nRt) - Ps = self.rxList.getP(mesh) - for ii, rx in enumerate(self.rxList.rxTypes): - fieldType = self.rxList._projField(rx) - u_part_complex = u[self, fieldType] + for ii, rx in enumerate(self.rxList): + P = rx.getP(mesh) + u_part_complex = u[self, rx.projField] # get the real or imag component - real_or_imag = self.rxList._projComp(rx) + real_or_imag = rx.projComp u_part = getattr(u_part_complex, real_or_imag) - Pu[ii] = Ps[ii]*u_part - + Pu[ii] = P*u_part return np.concatenate(Pu) def projectFieldsDeriv(self, mesh, u, v, adjoint=False): - Ps = self.rxList.getP(mesh) V = v.reshape((-1, len(Ps)), order='F') Pvs = range(len(Ps)) - for ii, rx in enumerate(self.rxList.rxTypes): - Pv = Ps[ii] * V[:, ii] - real_or_imag = self.rxList._projComp(rx) + for ii, rx in enumerate(self.rxList): + P = rx.getP(mesh) + + if not adjoint: + Pv = Ps[ii] * V[:, ii] + elif adjoint: + Pv = Ps[ii].T * V[:, ii] + + real_or_imag = rx.projComp if real_or_imag == 'imag': - Pv = 1j*Pv + Pvs[ii] = 1j*Pv elif real_or_imag == 'real': - Pv = Pv.astype(complex) + Pvs[ii] = Pv.astype(complex) else: raise NotImplementedError('must be real or imag') diff --git a/simpegEM/Tests/test_FDEM.py b/simpegEM/Tests/test_FDEM.py index eeeb9df8..c60446c9 100644 --- a/simpegEM/Tests/test_FDEM.py +++ b/simpegEM/Tests/test_FDEM.py @@ -18,24 +18,24 @@ def getProblem(fdemType): x = np.linspace(5,10,3) XYZ = Utils.ndgrid(x,x,np.r_[0]) if fdemType == 'e': - rxList = EM.FDEM.RxListFDEM(XYZ, 'bxr,exi') + rxList = EM.FDEM.RxFDEM(XYZ, 'bxr') elif fdemType == 'b': - rxList = EM.FDEM.RxListFDEM(XYZ, 'exi') + rxList = EM.FDEM.RxFDEM(XYZ, 'exi') else: raise NotImplementedError() - Tx0 = EM.FDEM.TxFDEM(np.r_[4.,2.,2.], 'VMD', 1e-2, rxList) + Tx0 = EM.FDEM.TxFDEM(np.r_[4.,2.,2.], 'VMD', 1e-2, [rxList]) x = np.linspace(5,10,3) XYZ = Utils.ndgrid(x,x,np.r_[0]) if fdemType == 'e': - rxList = EM.FDEM.RxListFDEM(XYZ, 'eyi') + rxList = EM.FDEM.RxFDEM(XYZ, 'eyi') elif fdemType == 'b': - rxList = EM.FDEM.RxListFDEM(XYZ, 'eyr') + rxList = EM.FDEM.RxFDEM(XYZ, 'eyr') else: raise NotImplementedError() - Tx1 = EM.FDEM.TxFDEM(np.r_[4.,2.,2.], 'VMD', 1e-4, rxList) + Tx1 = EM.FDEM.TxFDEM(np.r_[4.,2.,2.], 'VMD', 1e-4, [rxList]) survey = EM.FDEM.SurveyFDEM([Tx0, Tx1]) diff --git a/simpegEM/Tests/test_FieldsObject.py b/simpegEM/Tests/test_FieldsObject.py index 27ee18f3..82703d8a 100644 --- a/simpegEM/Tests/test_FieldsObject.py +++ b/simpegEM/Tests/test_FieldsObject.py @@ -9,15 +9,15 @@ class FieldsTest(unittest.TestCase): x = np.linspace(5,10,3) XYZ = Utils.ndgrid(x,x,np.r_[0.]) txLoc = np.r_[0,0,0.] - rxList0 = EM.FDEM.RxListFDEM(XYZ, 'exi,exr,eyi,eyr,ezi,ezr') - Tx0 = EM.FDEM.TxFDEM(txLoc, 'VMD', 3., rxList0) - rxList1 = EM.FDEM.RxListFDEM(XYZ, 'bxi,bxr,byi,byr,bzi,bzr') - Tx1 = EM.FDEM.TxFDEM(txLoc, 'VMD', 3., rxList1) - rxList2 = EM.FDEM.RxListFDEM(XYZ, 'bxi,eyr') - Tx2 = EM.FDEM.TxFDEM(txLoc, 'VMD', 2., rxList2) - rxList3 = EM.FDEM.RxListFDEM(XYZ, 'bxi') - Tx3 = EM.FDEM.TxFDEM(txLoc, 'VMD', 2., rxList3) - Tx4 = EM.FDEM.TxFDEM(txLoc, 'VMD', 1., rxList0) + rxList0 = EM.FDEM.RxFDEM(XYZ, 'exi') + Tx0 = EM.FDEM.TxFDEM(txLoc, 'VMD', 3., [rxList0]) + rxList1 = EM.FDEM.RxFDEM(XYZ, 'bxi') + Tx1 = EM.FDEM.TxFDEM(txLoc, 'VMD', 3., [rxList1]) + rxList2 = EM.FDEM.RxFDEM(XYZ, 'bxi') + Tx2 = EM.FDEM.TxFDEM(txLoc, 'VMD', 2., [rxList2]) + rxList3 = EM.FDEM.RxFDEM(XYZ, 'bxi') + Tx3 = EM.FDEM.TxFDEM(txLoc, 'VMD', 2., [rxList3]) + Tx4 = EM.FDEM.TxFDEM(txLoc, 'VMD', 1., [rxList0, rxList1, rxList2, rxList3]) txList = [Tx0,Tx1,Tx2,Tx3,Tx4] survey = EM.FDEM.SurveyFDEM(txList) self.F = EM.FDEM.FieldsFDEM(mesh, survey) @@ -101,13 +101,13 @@ class FieldsTest(unittest.TestCase): for ii, tx in enumerate(Txs): dat = tx.projectFields(self.mesh, F) self.assertTrue(dat.dtype == float) - dat = dat.reshape((self.XYZ.shape[0], len(tx.rxList.rxTypes)), order='F') - for jj, rx in enumerate(tx.rxList.rxTypes): - fieldType = tx.rxList._projField(rx) + dat = dat.reshape((self.XYZ.shape[0], len(tx.rxList)), order='F') + for jj, rx in enumerate(tx.rxList): + fieldType = rx.projField u = {'b':b[:,ii], 'e': e[:,ii]}[fieldType] - real_or_imag = tx.rxList._projComp(rx) + real_or_imag = rx.projComp u = getattr(u, real_or_imag) - gloc = tx.rxList._projGLoc(rx) + gloc = rx.projGLoc d = self.mesh.getInterpolationMat(self.XYZ, gloc)*u self.assertTrue(np.all(dat[:, jj] == d))