Projections and derivatives working.

This commit is contained in:
rowanc1
2014-03-22 13:22:44 -07:00
parent b7d47bcf38
commit 9acaae66db
4 changed files with 55 additions and 73 deletions
+18 -29
View File
@@ -110,15 +110,15 @@ class BaseProblemFDEM(Problem.BaseProblem):
u_tx = u[tx, self.solType]
w = self.getADeriv(freq, u_tx, v)
Ainvw = solver.solve(w)
fAinvw = self._calcFieldsList(Ainvw, freq, tx.rxList.fieldTypes)
P = lambda v: tx.projectFieldsDeriv(self.mesh, u, v)
for rx in tx.rxList:
fAinvw = self.calcFields(Ainvw, freq, rx.projField)
P = lambda v: rx.projectFieldsDeriv(tx, self.mesh, u, v)
df_dm = self._calcFieldsDerivList(u_tx, freq, tx.rxList.fieldTypes, v)
#TODO: this is now a list?
if df_dm is None:
Jv[tx] = - P(fAinvw)
else:
Jv[tx] = - P(fAinvw) + P(df_dm)
df_dm = self.calcFieldsDeriv(u_tx, freq, rx.projField, v)
if df_dm is None:
Jv[tx, rx] = - P(fAinvw)
else:
Jv[tx, rx] = - P(fAinvw) + P(df_dm)
return Utils.mkvc(Jv)
@@ -141,33 +141,22 @@ class BaseProblemFDEM(Problem.BaseProblem):
for tx in self.survey.getTransmitters(freq):
u_tx = u[tx, self.solType]
PTv = tx.projectFieldsDeriv(self.mesh, u, v[tx], adjoint=True)
fPTv = self._calcFieldsList(PTv, freq, tx.rxList.fieldTypes, adjoint=True)
for rx in tx.rxList:
PTv = rx.projectFieldsDeriv(tx, self.mesh, u, v[tx, rx], adjoint=True)
fPTv = self.calcFields(PTv, freq, rx.projField, adjoint=True)
w = solver.solve( fPTv )
Jtv_tx = self.getADeriv(freq, u_tx, w, adjoint=True)
w = solver.solve( fPTv )
Jtv_tx = self.getADeriv(freq, u_tx, w, adjoint=True)
df_dm = self._calcFieldsDerivList(u_tx, freq, tx.rxList.fieldTypes, PTv, adjoint=True)
df_dm = self.calcFieldsDeriv(u_tx, freq, rx.projField, PTv, adjoint=True)
if df_dm is None:
Jtv += - Jtv_tx
else:
Jtv += - Jtv_tx + df_dm
if df_dm is None:
Jtv += - Jtv_tx
else:
Jtv += - Jtv_tx + df_dm
return Jtv
def _calcFieldsList(self, sol, freq, fieldTypes, adjoint=False):
fVecs = range(len(fieldTypes))
for ii, fieldType in enumerate(fieldTypes):
fVecs[ii] = self.calcFields(sol, freq, fieldType, adjoint=adjoint)
return np.concatenate(fVecs)
def _calcFieldsDerivList(self, sol, freq, fieldTypes, v, adjoint=False):
fVecs = range(len(fieldTypes))
V = v.reshape((-1, len(fieldTypes)), order='F')
for ii, fieldType in enumerate(fieldTypes):
fVecs[ii] = self.calcFieldsDeriv(sol, freq, fieldType, V[:,ii], adjoint=adjoint)
return np.concatenate(fVecs)
class ProblemFDEM_e(BaseProblemFDEM):
"""
+31 -38
View File
@@ -61,6 +61,35 @@ class RxFDEM(Survey.BaseRx):
P = self._Ps[gloc][mesh]
return P
def projectFields(self, tx, mesh, u):
P = self.getP(mesh)
u_part_complex = u[tx, self.projField]
# get the real or imag component
real_or_imag = self.projComp
u_part = getattr(u_part_complex, real_or_imag)
return P*u_part
def projectFieldsDeriv(self, tx, mesh, u, v, adjoint=False):
P = self.getP(mesh)
if not adjoint:
Pv_complex = P * v
#TODO: check this deriv...
real_or_imag = self.projComp
Pv = getattr(Pv_complex, real_or_imag)
elif adjoint:
Pv_real = P.T * v
real_or_imag = self.projComp
if real_or_imag == 'imag':
Pv = 1j*Pv_real
elif real_or_imag == 'real':
Pv = Pv_real.astype(complex)
else:
raise NotImplementedError('must be real or imag')
return Pv
class TxFDEM(Survey.BaseTx):
@@ -84,41 +113,6 @@ class TxFDEM(Survey.BaseTx):
"""Vector number of data"""
return np.array([rx.nD for rx in self.rxList])
def projectFields(self, mesh, u):
nRt = len(self.rxList)
Pu = range(nRt)
for ii, rx in enumerate(self.rxList):
P = rx.getP(mesh)
u_part_complex = u[self, rx.projField]
# get the real or imag component
real_or_imag = rx.projComp
u_part = getattr(u_part_complex, real_or_imag)
Pu[ii] = P*u_part
return np.concatenate(Pu)
def projectFieldsDeriv(self, mesh, u, v, adjoint=False):
V = v.reshape((-1, len(Ps)), order='F')
Pvs = range(len(Ps))
for ii, rx in enumerate(self.rxList):
P = rx.getP(mesh)
if not adjoint:
Pv = Ps[ii] * V[:, ii]
elif adjoint:
Pv = Ps[ii].T * V[:, ii]
real_or_imag = rx.projComp
if real_or_imag == 'imag':
Pvs[ii] = 1j*Pv
elif real_or_imag == 'real':
Pvs[ii] = Pv.astype(complex)
else:
raise NotImplementedError('must be real or imag')
return np.concatenate(Pvs)
class FieldsFDEM(object):
"""Fancy Field Storage for a FDEM survey."""
@@ -285,8 +279,6 @@ class DataFDEM(object):
indBot += rx.nD
class SurveyFDEM(Survey.BaseSurvey):
"""
docstring for SurveyFDEM
@@ -349,7 +341,8 @@ class SurveyFDEM(Survey.BaseSurvey):
def projectFields(self, u):
data = DataFDEM(self)
for tx in self.txList:
data[tx] = tx.projectFields(self.mesh, u)
for rx in tx.rxList:
data[tx, rx] = rx.projectFields(tx, self.mesh, u)
return data
def projectFieldsDeriv(self, u):
+3 -2
View File
@@ -70,8 +70,9 @@ class FDEM_DerivTests_e(unittest.TestCase):
m = self.sigma
u = self.prb.fields(m)
vJw = v.dot(self.prb.Jvec(m, w, u=u))
wJtv = w.dot(self.prb.Jtvec(m, v, u=u))
vJw = np.vdot(v, self.prb.Jvec(m, w, u=u))
wJtv = np.vdot(w, self.prb.Jtvec(m, v, u=u))
print 'Jtvec: ', vJw - wJtv
self.assertTrue(vJw - wJtv < TOL)
+3 -4
View File
@@ -100,17 +100,16 @@ class FieldsTest(unittest.TestCase):
Txs = F.survey.getTransmitters(freq)
for ii, tx in enumerate(Txs):
dat = tx.projectFields(self.mesh, F)
self.assertTrue(dat.dtype == float)
dat = dat.reshape((self.XYZ.shape[0], len(tx.rxList)), order='F')
for jj, rx in enumerate(tx.rxList):
dat = rx.projectFields(tx, self.mesh, F)
self.assertTrue(dat.dtype == float)
fieldType = rx.projField
u = {'b':b[:,ii], 'e': e[:,ii]}[fieldType]
real_or_imag = rx.projComp
u = getattr(u, real_or_imag)
gloc = rx.projGLoc
d = self.mesh.getInterpolationMat(self.XYZ, gloc)*u
self.assertTrue(np.all(dat[:, jj] == d))
self.assertTrue(np.all(dat == d))