mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-04 11:55:17 +08:00
Adding fixes to code.
This commit is contained in:
File diff suppressed because it is too large
Load Diff
+59
-52
@@ -165,11 +165,16 @@ class MTProblem(Problem.BaseProblem):
|
||||
:rtype: scipy.sparse.csr_matrix
|
||||
:return: A
|
||||
"""
|
||||
mui = self.MfMui
|
||||
from SimPEG import Mesh
|
||||
Mback = Mesh.TensorMesh(self.mesh.h,self.mesh.x0)
|
||||
Mback.setCellGradBC('dirichlet')
|
||||
mui = Mback.getFaceInnerProduct(1/mu_0)
|
||||
sigmaBG = self.backModel
|
||||
MsigBG = Mback.getEdgeInnerProduct(sigmaBG)
|
||||
sigBG = self.MeSigmaBG
|
||||
C = self.mesh.edgeCurl
|
||||
C = Mback.edgeCurl
|
||||
|
||||
return C.T*mui*C - 1j*omega(freq)*sigBG
|
||||
return C.T*mui*C - 1j*omega(freq)*MsigBG
|
||||
|
||||
def getADeriv(self, freq, u, v, adjoint=False):
|
||||
sig = self.curTModel
|
||||
@@ -198,7 +203,7 @@ class MTProblem(Problem.BaseProblem):
|
||||
eBG_bp = homo1DModelSource(self.mesh,freq,backSigma)
|
||||
Abg = self.getAbg(freq)
|
||||
|
||||
return -Abg*eBG_bp, eBG_bp
|
||||
return Abg*eBG_bp, eBG_bp
|
||||
|
||||
##################################################################
|
||||
# Inversion stuff
|
||||
@@ -228,70 +233,72 @@ class MTProblem(Problem.BaseProblem):
|
||||
|
||||
|
||||
def Jvec(self, m, v, u=None):
|
||||
if u is None:
|
||||
u = self.fields(m)
|
||||
# if u is None:
|
||||
# u = self.fields(m)
|
||||
|
||||
self.curModel = m
|
||||
# self.curModel = m
|
||||
|
||||
Jv = self.dataPair(self.survey)
|
||||
# Jv = self.dataPair(self.survey)
|
||||
|
||||
for freq in self.survey.freqs:
|
||||
A = self.getA(freq)
|
||||
solver = self.Solver(A, **self.solverOpts)
|
||||
# for freq in self.survey.freqs:
|
||||
# A = self.getA(freq)
|
||||
# solver = self.Solver(A, **self.solverOpts)
|
||||
|
||||
for tx in self.survey.getTransmitters(freq):
|
||||
u_tx = u[tx, self.solType]
|
||||
w = self.getADeriv(freq, u_tx, v)
|
||||
Ainvw = solver.solve(w)
|
||||
for rx in tx.rxList:
|
||||
fAinvw = self.calcFields(Ainvw, freq, rx.projField)
|
||||
P = lambda v: rx.projectFieldsDeriv(tx, self.mesh, u, v)
|
||||
# for tx in self.survey.getTransmitters(freq):
|
||||
# u_tx = u[tx, self.solType]
|
||||
# w = self.getADeriv(freq, u_tx, v)
|
||||
# Ainvw = solver.solve(w)
|
||||
# for rx in tx.rxList:
|
||||
# fAinvw = self.calcFields(Ainvw, freq, rx.projField)
|
||||
# P = lambda v: rx.projectFieldsDeriv(tx, self.mesh, u, v)
|
||||
|
||||
df_dm = self.calcFieldsDeriv(u_tx, freq, rx.projField, v)
|
||||
if df_dm is None:
|
||||
Jv[tx, rx] = - P(fAinvw)
|
||||
else:
|
||||
Jv[tx, rx] = - P(fAinvw) + P(df_dm)
|
||||
# df_dm = self.calcFieldsDeriv(u_tx, freq, rx.projField, v)
|
||||
# if df_dm is None:
|
||||
# Jv[tx, rx] = - P(fAinvw)
|
||||
# else:
|
||||
# Jv[tx, rx] = - P(fAinvw) + P(df_dm)
|
||||
|
||||
return Utils.mkvc(Jv)
|
||||
# return Utils.mkvc(Jv)
|
||||
pass
|
||||
|
||||
def Jtvec(self, m, v, u=None):
|
||||
if u is None:
|
||||
u = self.fields(m)
|
||||
# if u is None:
|
||||
# u = self.fields(m)
|
||||
|
||||
self.curModel = m
|
||||
# self.curModel = m
|
||||
|
||||
# Ensure v is a data object.
|
||||
if not isinstance(v, self.dataPair):
|
||||
v = self.dataPair(self.survey, v)
|
||||
# # Ensure v is a data object.
|
||||
# if not isinstance(v, self.dataPair):
|
||||
# v = self.dataPair(self.survey, v)
|
||||
|
||||
Jtv = np.zeros(self.mapping.nP)
|
||||
# Jtv = np.zeros(self.mapping.nP)
|
||||
|
||||
for freq in self.survey.freqs:
|
||||
AT = self.getA(freq).T
|
||||
solver = self.Solver(AT, **self.solverOpts)
|
||||
# for freq in self.survey.freqs:
|
||||
# AT = self.getA(freq).T
|
||||
# solver = self.Solver(AT, **self.solverOpts)
|
||||
|
||||
for tx in self.survey.getTransmitters(freq):
|
||||
u_tx = u[tx, self.solType]
|
||||
# for tx in self.survey.getTransmitters(freq):
|
||||
# u_tx = u[tx, self.solType]
|
||||
|
||||
for rx in tx.rxList:
|
||||
PTv = rx.projectFieldsDeriv(tx, self.mesh, u, v[tx, rx], adjoint=True)
|
||||
fPTv = self.calcFields(PTv, freq, rx.projField, adjoint=True)
|
||||
# for rx in tx.rxList:
|
||||
# PTv = rx.projectFieldsDeriv(tx, self.mesh, u, v[tx, rx], adjoint=True)
|
||||
# fPTv = self.calcFields(PTv, freq, rx.projField, adjoint=True)
|
||||
|
||||
w = solver.solve( fPTv )
|
||||
Jtv_rx = - self.getADeriv(freq, u_tx, w, adjoint=True)
|
||||
# w = solver.solve( fPTv )
|
||||
# Jtv_rx = - self.getADeriv(freq, u_tx, w, adjoint=True)
|
||||
|
||||
df_dm = self.calcFieldsDeriv(u_tx, freq, rx.projField, PTv, adjoint=True)
|
||||
# df_dm = self.calcFieldsDeriv(u_tx, freq, rx.projField, PTv, adjoint=True)
|
||||
|
||||
if df_dm is not None:
|
||||
Jtv_rx += df_dm
|
||||
# if df_dm is not None:
|
||||
# Jtv_rx += df_dm
|
||||
|
||||
real_or_imag = rx.projComp
|
||||
if real_or_imag == 'real':
|
||||
Jtv += Jtv_rx.real
|
||||
elif real_or_imag == 'imag':
|
||||
Jtv += - Jtv_rx.real
|
||||
else:
|
||||
raise Exception('Must be real or imag')
|
||||
# real_or_imag = rx.projComp
|
||||
# if real_or_imag == 'real':
|
||||
# Jtv += Jtv_rx.real
|
||||
# elif real_or_imag == 'imag':
|
||||
# Jtv += - Jtv_rx.real
|
||||
# else:
|
||||
# raise Exception('Must be real or imag')
|
||||
|
||||
return Jtv
|
||||
# return Jtv
|
||||
pass
|
||||
+1
-10
@@ -129,7 +129,6 @@ class RxMT(Survey.BaseRx):
|
||||
return Pv
|
||||
|
||||
|
||||
# Call this Source or polarization or something...?
|
||||
# Note: Might need to add tests to make sure that both polarization have the same rxList.
|
||||
class srcMT(Survey.BaseSrc):
|
||||
'''
|
||||
@@ -195,7 +194,7 @@ class SurveyMT(Survey.BaseSurvey):
|
||||
|
||||
# TODO: Rename to getSources
|
||||
def getSources(self, freq):
|
||||
"""Returns the transmitters associated with a specific frequency."""
|
||||
"""Returns the sources associated with a specific frequency."""
|
||||
assert freq in self._freqDict, "The requested frequency is not in this survey."
|
||||
return self._freqDict[freq]
|
||||
|
||||
@@ -270,14 +269,6 @@ class DataMT(Survey.Data):
|
||||
outArr[comp] = outTemp[comp].copy()
|
||||
for comp in ['zxx','zxy','zyx','zyy']:
|
||||
outArr[comp] = outTemp[comp+'r'].copy() + 1j*outTemp[comp+'i'].copy()
|
||||
# for uniFL in uniFLmarr:
|
||||
# mTemp = mkvc(rec2ndarr(mArrRec[np.ma.where(mArrRec[['freq','x','y','z']].data == np.array(uniFL))][impList]).sum(axis=0),2).T
|
||||
# compBlock = np.sum(mTemp.data.reshape((4,2))*np.array([[1,1j],[1,1j],[1,1j],[1,1j]]),axis=1).copy().view(dt[4::])
|
||||
# dataBlock = mkvc(recFunc.merge_arrays((np.array(uniFL),compBlock),flatten=True),2).T
|
||||
# try:
|
||||
# outArr = recFunc.stack_arrays((outArr,dataBlock),usemask=False)
|
||||
# except NameError as e:
|
||||
# outArr = dataBlock
|
||||
|
||||
# Return
|
||||
return outArr
|
||||
Reference in New Issue
Block a user