mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-30 04:19:45 +08:00
Projections and derivatives working.
This commit is contained in:
+18
-29
@@ -110,15 +110,15 @@ class BaseProblemFDEM(Problem.BaseProblem):
|
||||
u_tx = u[tx, self.solType]
|
||||
w = self.getADeriv(freq, u_tx, v)
|
||||
Ainvw = solver.solve(w)
|
||||
fAinvw = self._calcFieldsList(Ainvw, freq, tx.rxList.fieldTypes)
|
||||
P = lambda v: tx.projectFieldsDeriv(self.mesh, u, v)
|
||||
for rx in tx.rxList:
|
||||
fAinvw = self.calcFields(Ainvw, freq, rx.projField)
|
||||
P = lambda v: rx.projectFieldsDeriv(tx, self.mesh, u, v)
|
||||
|
||||
df_dm = self._calcFieldsDerivList(u_tx, freq, tx.rxList.fieldTypes, v)
|
||||
#TODO: this is now a list?
|
||||
if df_dm is None:
|
||||
Jv[tx] = - P(fAinvw)
|
||||
else:
|
||||
Jv[tx] = - P(fAinvw) + P(df_dm)
|
||||
df_dm = self.calcFieldsDeriv(u_tx, freq, rx.projField, v)
|
||||
if df_dm is None:
|
||||
Jv[tx, rx] = - P(fAinvw)
|
||||
else:
|
||||
Jv[tx, rx] = - P(fAinvw) + P(df_dm)
|
||||
|
||||
return Utils.mkvc(Jv)
|
||||
|
||||
@@ -141,33 +141,22 @@ class BaseProblemFDEM(Problem.BaseProblem):
|
||||
for tx in self.survey.getTransmitters(freq):
|
||||
u_tx = u[tx, self.solType]
|
||||
|
||||
PTv = tx.projectFieldsDeriv(self.mesh, u, v[tx], adjoint=True)
|
||||
fPTv = self._calcFieldsList(PTv, freq, tx.rxList.fieldTypes, adjoint=True)
|
||||
for rx in tx.rxList:
|
||||
PTv = rx.projectFieldsDeriv(tx, self.mesh, u, v[tx, rx], adjoint=True)
|
||||
fPTv = self.calcFields(PTv, freq, rx.projField, adjoint=True)
|
||||
|
||||
w = solver.solve( fPTv )
|
||||
Jtv_tx = self.getADeriv(freq, u_tx, w, adjoint=True)
|
||||
w = solver.solve( fPTv )
|
||||
Jtv_tx = self.getADeriv(freq, u_tx, w, adjoint=True)
|
||||
|
||||
df_dm = self._calcFieldsDerivList(u_tx, freq, tx.rxList.fieldTypes, PTv, adjoint=True)
|
||||
df_dm = self.calcFieldsDeriv(u_tx, freq, rx.projField, PTv, adjoint=True)
|
||||
|
||||
if df_dm is None:
|
||||
Jtv += - Jtv_tx
|
||||
else:
|
||||
Jtv += - Jtv_tx + df_dm
|
||||
if df_dm is None:
|
||||
Jtv += - Jtv_tx
|
||||
else:
|
||||
Jtv += - Jtv_tx + df_dm
|
||||
|
||||
return Jtv
|
||||
|
||||
def _calcFieldsList(self, sol, freq, fieldTypes, adjoint=False):
|
||||
fVecs = range(len(fieldTypes))
|
||||
for ii, fieldType in enumerate(fieldTypes):
|
||||
fVecs[ii] = self.calcFields(sol, freq, fieldType, adjoint=adjoint)
|
||||
return np.concatenate(fVecs)
|
||||
|
||||
def _calcFieldsDerivList(self, sol, freq, fieldTypes, v, adjoint=False):
|
||||
fVecs = range(len(fieldTypes))
|
||||
V = v.reshape((-1, len(fieldTypes)), order='F')
|
||||
for ii, fieldType in enumerate(fieldTypes):
|
||||
fVecs[ii] = self.calcFieldsDeriv(sol, freq, fieldType, V[:,ii], adjoint=adjoint)
|
||||
return np.concatenate(fVecs)
|
||||
|
||||
class ProblemFDEM_e(BaseProblemFDEM):
|
||||
"""
|
||||
|
||||
+31
-38
@@ -61,6 +61,35 @@ class RxFDEM(Survey.BaseRx):
|
||||
P = self._Ps[gloc][mesh]
|
||||
return P
|
||||
|
||||
def projectFields(self, tx, mesh, u):
|
||||
P = self.getP(mesh)
|
||||
u_part_complex = u[tx, self.projField]
|
||||
# get the real or imag component
|
||||
real_or_imag = self.projComp
|
||||
u_part = getattr(u_part_complex, real_or_imag)
|
||||
return P*u_part
|
||||
|
||||
def projectFieldsDeriv(self, tx, mesh, u, v, adjoint=False):
|
||||
P = self.getP(mesh)
|
||||
|
||||
if not adjoint:
|
||||
Pv_complex = P * v
|
||||
#TODO: check this deriv...
|
||||
real_or_imag = self.projComp
|
||||
Pv = getattr(Pv_complex, real_or_imag)
|
||||
elif adjoint:
|
||||
Pv_real = P.T * v
|
||||
|
||||
real_or_imag = self.projComp
|
||||
if real_or_imag == 'imag':
|
||||
Pv = 1j*Pv_real
|
||||
elif real_or_imag == 'real':
|
||||
Pv = Pv_real.astype(complex)
|
||||
else:
|
||||
raise NotImplementedError('must be real or imag')
|
||||
|
||||
return Pv
|
||||
|
||||
|
||||
class TxFDEM(Survey.BaseTx):
|
||||
|
||||
@@ -84,41 +113,6 @@ class TxFDEM(Survey.BaseTx):
|
||||
"""Vector number of data"""
|
||||
return np.array([rx.nD for rx in self.rxList])
|
||||
|
||||
def projectFields(self, mesh, u):
|
||||
|
||||
nRt = len(self.rxList)
|
||||
Pu = range(nRt)
|
||||
|
||||
for ii, rx in enumerate(self.rxList):
|
||||
P = rx.getP(mesh)
|
||||
u_part_complex = u[self, rx.projField]
|
||||
# get the real or imag component
|
||||
real_or_imag = rx.projComp
|
||||
u_part = getattr(u_part_complex, real_or_imag)
|
||||
Pu[ii] = P*u_part
|
||||
return np.concatenate(Pu)
|
||||
|
||||
def projectFieldsDeriv(self, mesh, u, v, adjoint=False):
|
||||
V = v.reshape((-1, len(Ps)), order='F')
|
||||
Pvs = range(len(Ps))
|
||||
for ii, rx in enumerate(self.rxList):
|
||||
P = rx.getP(mesh)
|
||||
|
||||
if not adjoint:
|
||||
Pv = Ps[ii] * V[:, ii]
|
||||
elif adjoint:
|
||||
Pv = Ps[ii].T * V[:, ii]
|
||||
|
||||
real_or_imag = rx.projComp
|
||||
if real_or_imag == 'imag':
|
||||
Pvs[ii] = 1j*Pv
|
||||
elif real_or_imag == 'real':
|
||||
Pvs[ii] = Pv.astype(complex)
|
||||
else:
|
||||
raise NotImplementedError('must be real or imag')
|
||||
|
||||
return np.concatenate(Pvs)
|
||||
|
||||
|
||||
class FieldsFDEM(object):
|
||||
"""Fancy Field Storage for a FDEM survey."""
|
||||
@@ -285,8 +279,6 @@ class DataFDEM(object):
|
||||
indBot += rx.nD
|
||||
|
||||
|
||||
|
||||
|
||||
class SurveyFDEM(Survey.BaseSurvey):
|
||||
"""
|
||||
docstring for SurveyFDEM
|
||||
@@ -349,7 +341,8 @@ class SurveyFDEM(Survey.BaseSurvey):
|
||||
def projectFields(self, u):
|
||||
data = DataFDEM(self)
|
||||
for tx in self.txList:
|
||||
data[tx] = tx.projectFields(self.mesh, u)
|
||||
for rx in tx.rxList:
|
||||
data[tx, rx] = rx.projectFields(tx, self.mesh, u)
|
||||
return data
|
||||
|
||||
def projectFieldsDeriv(self, u):
|
||||
|
||||
@@ -70,8 +70,9 @@ class FDEM_DerivTests_e(unittest.TestCase):
|
||||
|
||||
m = self.sigma
|
||||
u = self.prb.fields(m)
|
||||
vJw = v.dot(self.prb.Jvec(m, w, u=u))
|
||||
wJtv = w.dot(self.prb.Jtvec(m, v, u=u))
|
||||
vJw = np.vdot(v, self.prb.Jvec(m, w, u=u))
|
||||
wJtv = np.vdot(w, self.prb.Jtvec(m, v, u=u))
|
||||
print 'Jtvec: ', vJw - wJtv
|
||||
self.assertTrue(vJw - wJtv < TOL)
|
||||
|
||||
|
||||
|
||||
@@ -100,17 +100,16 @@ class FieldsTest(unittest.TestCase):
|
||||
|
||||
Txs = F.survey.getTransmitters(freq)
|
||||
for ii, tx in enumerate(Txs):
|
||||
dat = tx.projectFields(self.mesh, F)
|
||||
self.assertTrue(dat.dtype == float)
|
||||
dat = dat.reshape((self.XYZ.shape[0], len(tx.rxList)), order='F')
|
||||
for jj, rx in enumerate(tx.rxList):
|
||||
dat = rx.projectFields(tx, self.mesh, F)
|
||||
self.assertTrue(dat.dtype == float)
|
||||
fieldType = rx.projField
|
||||
u = {'b':b[:,ii], 'e': e[:,ii]}[fieldType]
|
||||
real_or_imag = rx.projComp
|
||||
u = getattr(u, real_or_imag)
|
||||
gloc = rx.projGLoc
|
||||
d = self.mesh.getInterpolationMat(self.XYZ, gloc)*u
|
||||
self.assertTrue(np.all(dat[:, jj] == d))
|
||||
self.assertTrue(np.all(dat == d))
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user