Change rxList --> [rx,rx]

Projections are not shared at the moment, but this can be changed later.
This commit is contained in:
rowanc1
2014-03-22 12:24:54 -07:00
parent a8f09ee869
commit aacbc03cf6
4 changed files with 77 additions and 69 deletions
+17 -4
View File
@@ -110,10 +110,10 @@ class BaseProblemFDEM(Problem.BaseProblem):
u_tx = u[tx, self.solType]
w = self.getADeriv(freq, u_tx, v)
Ainvw = solver.solve(w)
fAinvw = self.calcFieldsList(Ainvw, freq, tx.rxList.fieldTypes)
fAinvw = self._calcFieldsList(Ainvw, freq, tx.rxList.fieldTypes)
P = lambda v: tx.projectFieldsDeriv(self.mesh, u, v)
df_dm = self.calcFieldsDerivList(u_tx, freq, tx.rxList.fieldTypes, v)
df_dm = self._calcFieldsDerivList(u_tx, freq, tx.rxList.fieldTypes, v)
#TODO: this is now a list?
if df_dm is None:
Jv[tx] = - P(fAinvw)
@@ -142,12 +142,12 @@ class BaseProblemFDEM(Problem.BaseProblem):
u_tx = u[tx, self.solType]
PTv = tx.projectFieldsDeriv(self.mesh, u, v[tx], adjoint=True)
fPTv = self.calcFields(PTv, freq, tx.rxList.fieldTypes, adjoint=True)
fPTv = self._calcFieldsList(PTv, freq, tx.rxList.fieldTypes, adjoint=True)
w = solver.solve( fPTv )
Jtv_tx = self.getADeriv(freq, u_tx, w, adjoint=True)
df_dm = self.calcFieldsDeriv(u_tx, freq, tx.rxList.fieldTypes, PTv, adjoint=True)
df_dm = self._calcFieldsDerivList(u_tx, freq, tx.rxList.fieldTypes, PTv, adjoint=True)
if df_dm is None:
Jtv += - Jtv_tx
@@ -156,6 +156,19 @@ class BaseProblemFDEM(Problem.BaseProblem):
return Jtv
def _calcFieldsList(self, sol, freq, fieldTypes, adjoint=False):
fVecs = range(len(fieldTypes))
for ii, fieldType in enumerate(fieldTypes):
fVecs[ii] = self.calcFields(sol, freq, fieldType, adjoint=adjoint)
return np.concatenate(fVecs)
def _calcFieldsDerivList(self, sol, freq, fieldTypes, v, adjoint=False):
fVecs = range(len(fieldTypes))
V = v.reshape((-1, len(fieldTypes)), order='F')
for ii, fieldType in enumerate(fieldTypes):
fVecs[ii] = self.calcFieldsDeriv(sol, freq, fieldType, V[:,ii], adjoint=adjoint)
return np.concatenate(fVecs)
class ProblemFDEM_e(BaseProblemFDEM):
"""
Solving for e!
+40 -45
View File
@@ -1,6 +1,6 @@
from SimPEG import Survey, Utils, np, sp
class RxListFDEM(Survey.BaseRxList):
class RxFDEM(Survey.BaseRx):
knownRxTypes = {
'exr':['e', 'Ex', 'real'],
@@ -19,36 +19,25 @@ class RxListFDEM(Survey.BaseRxList):
}
def __init__(self, locs, rxType):
Survey.BaseRxList.__init__(self, locs, rxType)
Survey.BaseRx.__init__(self, locs, rxType)
self._Ps = {}
for rx in self.rxTypes:
self._Ps[self._projGLoc(rx)] = {}
self._Ps[self.projGLoc] = {}
def _projField(self, rx):
@property
def projField(self):
"""Field Type projection (e.g. e b ...)"""
if type(rx) is int: rx = self.rxTypes[rx]
return self.knownRxTypes[rx][0]
return self.knownRxTypes[self.rxType][0]
def _projGLoc(self, rx):
@property
def projGLoc(self):
"""Grid Location projection (e.g. Ex Fy ...)"""
if type(rx) is int: rx = self.rxTypes[rx]
return self.knownRxTypes[rx][1]
return self.knownRxTypes[self.rxType][1]
def _projComp(self, rx):
@property
def projComp(self):
"""Component projection (real/imag)"""
if type(rx) is int: rx = self.rxTypes[rx]
return self.knownRxTypes[rx][2]
@property
def rxTypes(self):
"""A list of the recieve types that are collected at this rxList locations."""
return self.rxType.split(',')
@property
def fieldTypes(self):
#TODO: This assumes that it has the structure ex, by ...
return [self._projField(rx) for rx in self.rxTypes]
return self.knownRxTypes[self.rxType][2]
def getP(self, mesh):
"""
@@ -61,12 +50,10 @@ class RxListFDEM(Survey.BaseRxList):
Projection matrices are stored as a nested dict,
First gridLocation, then mesh.
"""
P = []
for rx in self.rxTypes:
gloc = self._projGLoc(rx)
if mesh not in self._Ps[gloc]:
self._Ps[gloc][mesh] = mesh.getInterpolationMat(self.locs, gloc)
P += [self._Ps[gloc][mesh]]
gloc = self.projGLoc
if mesh not in self._Ps[gloc]:
self._Ps[gloc][mesh] = mesh.getInterpolationMat(self.locs, gloc)
P = self._Ps[gloc][mesh]
return P
@@ -74,7 +61,7 @@ class TxFDEM(Survey.BaseTx):
freq = None #: Frequency (float)
rxListPair = RxListFDEM
rxListPair = RxFDEM
knownTxTypes = ['VMD']
@@ -85,35 +72,43 @@ class TxFDEM(Survey.BaseTx):
@property
def nD(self):
"""Number of data"""
return self.rxList.locs.shape[0]*len(self.rxList.rxTypes)
return self.vnD.sum()
@property
def vnD(self):
"""Vector number of data"""
return np.array([rx.locs.shape[0] for rx in self.rxList])
def projectFields(self, mesh, u):
nRt = len(self.rxList.rxTypes)
nRt = len(self.rxList)
Pu = range(nRt)
Ps = self.rxList.getP(mesh)
for ii, rx in enumerate(self.rxList.rxTypes):
fieldType = self.rxList._projField(rx)
u_part_complex = u[self, fieldType]
for ii, rx in enumerate(self.rxList):
P = rx.getP(mesh)
u_part_complex = u[self, rx.projField]
# get the real or imag component
real_or_imag = self.rxList._projComp(rx)
real_or_imag = rx.projComp
u_part = getattr(u_part_complex, real_or_imag)
Pu[ii] = Ps[ii]*u_part
Pu[ii] = P*u_part
return np.concatenate(Pu)
def projectFieldsDeriv(self, mesh, u, v, adjoint=False):
Ps = self.rxList.getP(mesh)
V = v.reshape((-1, len(Ps)), order='F')
Pvs = range(len(Ps))
for ii, rx in enumerate(self.rxList.rxTypes):
Pv = Ps[ii] * V[:, ii]
real_or_imag = self.rxList._projComp(rx)
for ii, rx in enumerate(self.rxList):
P = rx.getP(mesh)
if not adjoint:
Pv = Ps[ii] * V[:, ii]
elif adjoint:
Pv = Ps[ii].T * V[:, ii]
real_or_imag = rx.projComp
if real_or_imag == 'imag':
Pv = 1j*Pv
Pvs[ii] = 1j*Pv
elif real_or_imag == 'real':
Pv = Pv.astype(complex)
Pvs[ii] = Pv.astype(complex)
else:
raise NotImplementedError('must be real or imag')
+6 -6
View File
@@ -18,24 +18,24 @@ def getProblem(fdemType):
x = np.linspace(5,10,3)
XYZ = Utils.ndgrid(x,x,np.r_[0])
if fdemType == 'e':
rxList = EM.FDEM.RxListFDEM(XYZ, 'bxr,exi')
rxList = EM.FDEM.RxFDEM(XYZ, 'bxr')
elif fdemType == 'b':
rxList = EM.FDEM.RxListFDEM(XYZ, 'exi')
rxList = EM.FDEM.RxFDEM(XYZ, 'exi')
else:
raise NotImplementedError()
Tx0 = EM.FDEM.TxFDEM(np.r_[4.,2.,2.], 'VMD', 1e-2, rxList)
Tx0 = EM.FDEM.TxFDEM(np.r_[4.,2.,2.], 'VMD', 1e-2, [rxList])
x = np.linspace(5,10,3)
XYZ = Utils.ndgrid(x,x,np.r_[0])
if fdemType == 'e':
rxList = EM.FDEM.RxListFDEM(XYZ, 'eyi')
rxList = EM.FDEM.RxFDEM(XYZ, 'eyi')
elif fdemType == 'b':
rxList = EM.FDEM.RxListFDEM(XYZ, 'eyr')
rxList = EM.FDEM.RxFDEM(XYZ, 'eyr')
else:
raise NotImplementedError()
Tx1 = EM.FDEM.TxFDEM(np.r_[4.,2.,2.], 'VMD', 1e-4, rxList)
Tx1 = EM.FDEM.TxFDEM(np.r_[4.,2.,2.], 'VMD', 1e-4, [rxList])
survey = EM.FDEM.SurveyFDEM([Tx0, Tx1])
+14 -14
View File
@@ -9,15 +9,15 @@ class FieldsTest(unittest.TestCase):
x = np.linspace(5,10,3)
XYZ = Utils.ndgrid(x,x,np.r_[0.])
txLoc = np.r_[0,0,0.]
rxList0 = EM.FDEM.RxListFDEM(XYZ, 'exi,exr,eyi,eyr,ezi,ezr')
Tx0 = EM.FDEM.TxFDEM(txLoc, 'VMD', 3., rxList0)
rxList1 = EM.FDEM.RxListFDEM(XYZ, 'bxi,bxr,byi,byr,bzi,bzr')
Tx1 = EM.FDEM.TxFDEM(txLoc, 'VMD', 3., rxList1)
rxList2 = EM.FDEM.RxListFDEM(XYZ, 'bxi,eyr')
Tx2 = EM.FDEM.TxFDEM(txLoc, 'VMD', 2., rxList2)
rxList3 = EM.FDEM.RxListFDEM(XYZ, 'bxi')
Tx3 = EM.FDEM.TxFDEM(txLoc, 'VMD', 2., rxList3)
Tx4 = EM.FDEM.TxFDEM(txLoc, 'VMD', 1., rxList0)
rxList0 = EM.FDEM.RxFDEM(XYZ, 'exi')
Tx0 = EM.FDEM.TxFDEM(txLoc, 'VMD', 3., [rxList0])
rxList1 = EM.FDEM.RxFDEM(XYZ, 'bxi')
Tx1 = EM.FDEM.TxFDEM(txLoc, 'VMD', 3., [rxList1])
rxList2 = EM.FDEM.RxFDEM(XYZ, 'bxi')
Tx2 = EM.FDEM.TxFDEM(txLoc, 'VMD', 2., [rxList2])
rxList3 = EM.FDEM.RxFDEM(XYZ, 'bxi')
Tx3 = EM.FDEM.TxFDEM(txLoc, 'VMD', 2., [rxList3])
Tx4 = EM.FDEM.TxFDEM(txLoc, 'VMD', 1., [rxList0, rxList1, rxList2, rxList3])
txList = [Tx0,Tx1,Tx2,Tx3,Tx4]
survey = EM.FDEM.SurveyFDEM(txList)
self.F = EM.FDEM.FieldsFDEM(mesh, survey)
@@ -101,13 +101,13 @@ class FieldsTest(unittest.TestCase):
for ii, tx in enumerate(Txs):
dat = tx.projectFields(self.mesh, F)
self.assertTrue(dat.dtype == float)
dat = dat.reshape((self.XYZ.shape[0], len(tx.rxList.rxTypes)), order='F')
for jj, rx in enumerate(tx.rxList.rxTypes):
fieldType = tx.rxList._projField(rx)
dat = dat.reshape((self.XYZ.shape[0], len(tx.rxList)), order='F')
for jj, rx in enumerate(tx.rxList):
fieldType = rx.projField
u = {'b':b[:,ii], 'e': e[:,ii]}[fieldType]
real_or_imag = tx.rxList._projComp(rx)
real_or_imag = rx.projComp
u = getattr(u, real_or_imag)
gloc = tx.rxList._projGLoc(rx)
gloc = rx.projGLoc
d = self.mesh.getInterpolationMat(self.XYZ, gloc)*u
self.assertTrue(np.all(dat[:, jj] == d))