Merge pull request #175 from simpeg/em/ZeroIdentity

Em/zero identity
This commit is contained in:
Rowan Cockett
2015-11-24 19:09:08 -08:00
8 changed files with 110 additions and 235 deletions
+48 -139
View File
@@ -66,34 +66,28 @@ class BaseFDEMProblem(BaseEMProblem):
Jv = self.dataPair(self.survey)
for freq in self.survey.freqs:
dA_du = self.getA(freq) #
dA_duI = self.Solver(dA_du, **self.solverOpts)
A = self.getA(freq) #
Ainv = self.Solver(A, **self.solverOpts)
for src in self.survey.getSrcByFreq(freq):
ftype = self._fieldType + 'Solution'
u_src = f[src, ftype]
dA_dm = self.getADeriv_m(freq, u_src, v)
dRHS_dm = self.getRHSDeriv_m(src, v)
if dRHS_dm is None:
du_dm = dA_duI * ( - dA_dm )
else:
du_dm = dA_duI * ( - dA_dm + dRHS_dm )
dRHS_dm = self.getRHSDeriv_m(freq, src, v)
du_dm = Ainv * ( - dA_dm + dRHS_dm )
for rx in src.rxList:
# df_duFun = u.deriv_u(rx.fieldsUsed, m)
df_duFun = getattr(f, '_%sDeriv_u'%rx.projField, None)
df_du = df_duFun(src, du_dm, adjoint=False)
if df_du is not None:
du_dm = df_du
df_dudu_dm = df_duFun(src, du_dm, adjoint=False)
df_dmFun = getattr(f, '_%sDeriv_m'%rx.projField, None)
df_dm = df_dmFun(src, v, adjoint=False)
if df_dm is not None:
du_dm += df_dm
Df_Dm = np.array(df_dudu_dm + df_dm,dtype=complex)
P = lambda v: rx.projectFieldsDeriv(src, self.mesh, f, v) # wrt u, also have wrt m
Jv[src, rx] = P(du_dm)
Jv[src, rx] = P(Df_Dm)
return Utils.mkvc(Jv)
@@ -126,30 +120,23 @@ class BaseFDEMProblem(BaseEMProblem):
df_duTFun = getattr(f, '_%sDeriv_u'%rx.projField, None)
df_duT = df_duTFun(src, PTv, adjoint=True)
if df_duT is not None:
dA_duIT = ATinv * df_duT
else:
dA_duIT = ATinv * PTv
ATinvdf_duT = ATinv * df_duT
dA_dmT = self.getADeriv_m(freq, u_src, dA_duIT, adjoint=True)
dRHS_dmT = self.getRHSDeriv_m(src, dA_duIT, adjoint=True)
if dRHS_dmT is None:
du_dmT = - dA_dmT
else:
du_dmT = -dA_dmT + dRHS_dmT
dA_dmT = self.getADeriv_m(freq, u_src, ATinvdf_duT, adjoint=True)
dRHS_dmT = self.getRHSDeriv_m(freq,src, ATinvdf_duT, adjoint=True)
du_dmT = -dA_dmT + dRHS_dmT
df_dmFun = getattr(f, '_%sDeriv_m'%rx.projField, None)
dfT_dm = df_dmFun(src, PTv, adjoint=True)
if dfT_dm is not None:
du_dmT += dfT_dm
du_dmT += dfT_dm
real_or_imag = rx.projComp
if real_or_imag == 'real':
Jtv += du_dmT.real
elif real_or_imag == 'imag':
Jtv += - du_dmT.real
if real_or_imag is 'real':
Jtv += np.array(du_dmT,dtype=complex).real
elif real_or_imag is 'imag':
Jtv += - np.array(du_dmT,dtype=complex).real
else:
raise Exception('Must be real or imag')
@@ -173,10 +160,8 @@ class BaseFDEMProblem(BaseEMProblem):
for i, src in enumerate(Srcs):
smi, sei = src.eval(self)
if smi is not None:
S_m[:,i] = Utils.mkvc(smi)
if sei is not None:
S_e[:,i] = Utils.mkvc(sei)
S_m[:,i] = S_m[:,i] + smi
S_e[:,i] = S_e[:,i] + sei
return S_m, S_e
@@ -249,39 +234,21 @@ class Problem_e(BaseFDEMProblem):
C = self.mesh.edgeCurl
MfMui = self.MfMui
# RHS = C.T * (MfMui * S_m) -1j * omega(freq) * Me * S_e
RHS = C.T * (MfMui * S_m) -1j * omega(freq) * S_e
return RHS
def getRHSDeriv_m(self, src, v, adjoint=False):
def getRHSDeriv_m(self, freq, src, v, adjoint=False):
C = self.mesh.edgeCurl
MfMui = self.MfMui
S_mDeriv, S_eDeriv = src.evalDeriv(self, adjoint)
if adjoint:
dRHS = MfMui * (C * v)
S_mDerivv = S_mDeriv(dRHS)
S_eDerivv = S_eDeriv(v)
if S_mDerivv is not None and S_eDerivv is not None:
return S_mDerivv - 1j * omega(freq) * S_eDerivv
elif S_mDerivv is not None:
return S_mDerivv
elif S_eDerivv is not None:
return - 1j * omega(freq) * S_eDerivv
else:
return None
else:
S_mDerivv, S_eDerivv = S_mDeriv(v), S_eDeriv(v)
return S_mDeriv(dRHS) - 1j * omega(freq) * S_eDeriv(v)
if S_mDerivv is not None and S_eDerivv is not None:
return C.T * (MfMui * S_mDerivv) -1j * omega(freq) * S_eDerivv
elif S_mDerivv is not None:
return C.T * (MfMui * S_mDerivv)
elif S_eDerivv is not None:
return -1j * omega(freq) * S_eDerivv
else:
return None
else:
return C.T * (MfMui * S_mDeriv(v)) -1j * omega(freq) * S_eDeriv(v)
class Problem_b(BaseFDEMProblem):
@@ -362,7 +329,6 @@ class Problem_b(BaseFDEMProblem):
S_m, S_e = self.getSourceTerm(freq)
C = self.mesh.edgeCurl
MeSigmaI = self.MeSigmaI
# Me = self.Me
RHS = S_m + C * ( MeSigmaI * S_e )
@@ -372,51 +338,28 @@ class Problem_b(BaseFDEMProblem):
return RHS
def getRHSDeriv_m(self, src, v, adjoint=False):
def getRHSDeriv_m(self, freq, src, v, adjoint=False):
C = self.mesh.edgeCurl
S_m, S_e = src.eval(self)
MfMui = self.MfMui
# Me = self.Me
if self._makeASymmetric and adjoint:
v = self.MfMui * v
if S_e is not None:
MeSigmaIDeriv = self.MeSigmaIDeriv(S_e)
if not adjoint:
RHSderiv = C * (MeSigmaIDeriv * v)
elif adjoint:
RHSderiv = MeSigmaIDeriv.T * (C.T * v)
else:
RHSderiv = None
MeSigmaIDeriv = self.MeSigmaIDeriv(S_e)
S_mDeriv, S_eDeriv = src.evalDeriv(self, adjoint)
S_mDeriv, S_eDeriv = S_mDeriv(v), S_eDeriv(v)
if S_mDeriv is not None and S_eDeriv is not None:
if not adjoint:
SrcDeriv = S_mDeriv + C * (self.MeSigmaI * S_eDeriv)
elif adjoint:
SrcDeriv = S_mDeriv + Self.MeSigmaI.T * ( C.T * S_eDeriv)
elif S_mDeriv is not None:
SrcDeriv = S_mDeriv
elif S_eDeriv is not None:
if not adjoint:
SrcDeriv = C * (self.MeSigmaI * S_eDeriv)
elif adjoint:
SrcDeriv = self.MeSigmaI.T * ( C.T * S_eDeriv)
else:
SrcDeriv = None
if RHSderiv is not None and SrcDeriv is not None:
RHSderiv += SrcDeriv
elif SrcDeriv is not None:
RHSderiv = SrcDeriv
if not adjoint:
RHSderiv = C * (MeSigmaIDeriv * v)
SrcDeriv = S_mDeriv(v) + C * (self.MeSigmaI * S_eDeriv(v))
elif adjoint:
RHSderiv = MeSigmaIDeriv.T * (C.T * v)
SrcDeriv = S_mDeriv(v) + self.MeSigmaI.T * (C.T * S_eDeriv(v))
if RHSderiv is not None:
if self._makeASymmetric is True and not adjoint:
return MfMui.T * RHSderiv
if self._makeASymmetric is True and not adjoint:
return MfMui.T * (SrcDeriv + RHSderiv)
return RHSderiv
return RHSderiv + SrcDeriv
@@ -519,7 +462,7 @@ class Problem_j(BaseFDEMProblem):
return RHS
def getRHSDeriv_m(self, src, v, adjoint=False):
def getRHSDeriv_m(self, freq, src, v, adjoint=False):
C = self.mesh.edgeCurl
MeMuI = self.MeMuI
S_mDeriv, S_eDeriv = src.evalDeriv(self, adjoint)
@@ -528,27 +471,10 @@ class Problem_j(BaseFDEMProblem):
if self._makeASymmetric:
MfRho = self.MfRho
v = MfRho*v
S_mDerivv = S_mDeriv(MeMuI.T * (C.T * v))
S_eDerivv = S_eDeriv(v)
if S_mDerivv is not None and S_eDerivv is not None:
return S_mDerivv - 1j * omega(freq) * S_eDerivv
elif S_mDerivv is not None:
return S_mDerivv
elif S_eDerivv is not None:
return - 1j * omega(freq) * S_eDerivv
else:
return None
else:
S_mDerivv, S_eDerivv = S_mDeriv(v), S_eDeriv(v)
return S_mDeriv(MeMuI.T * (C.T * v)) - 1j * omega(freq) * S_eDeriv(v)
if S_mDerivv is not None and S_eDerivv is not None:
RHSDeriv = C * (MeMuI * S_mDerivv) - 1j * omega(freq) * S_eDerivv
elif S_mDerivv is not None:
RHSDeriv = C * (MeMuI * S_mDerivv)
elif S_eDerivv is not None:
RHSDeriv = - 1j * omega(freq) * S_eDerivv
else:
return None
else:
RHSDeriv = C * (MeMuI * S_mDeriv(v)) - 1j * omega(freq) * S_eDeriv(v)
if self._makeASymmetric:
MfRho = self.MfRho
@@ -627,35 +553,18 @@ class Problem_h(BaseFDEMProblem):
return RHS
def getRHSDeriv_m(self, src, v, adjoint=False):
def getRHSDeriv_m(self, freq, src, v, adjoint=False):
_, S_e = src.eval(self)
C = self.mesh.edgeCurl
MfRho = self.MfRho
RHSDeriv = None
if S_e is not None:
MfRhoDeriv = self.MfRhoDeriv(S_e)
if not adjoint:
RHSDeriv = C.T * (MfRhoDeriv * v)
elif adjoint:
RHSDeriv = MfRhoDeriv.T * (C * v)
MfRhoDeriv = self.MfRhoDeriv(S_e)
if not adjoint:
RHSDeriv = C.T * (MfRhoDeriv * v)
elif adjoint:
RHSDeriv = MfRhoDeriv.T * (C * v)
S_mDeriv, S_eDeriv = src.evalDeriv(self, adjoint)
S_mDeriv = S_mDeriv(v)
S_eDeriv = S_eDeriv(v)
if S_mDeriv is not None:
if RHSDeriv is not None:
RHSDeriv += S_mDeriv(v)
else:
RHSDeriv = S_mDeriv(v)
if S_eDeriv is not None:
if RHSDeriv is not None:
RHSDeriv += C.T * (MfRho * S_e)
else:
RHSDeriv = C.T * (MfRho * S_e)
return RHSDeriv
return RHSDeriv + S_mDeriv(v) + C.T * (MfRho * S_eDeriv(v))
+28 -47
View File
@@ -3,6 +3,7 @@ import scipy.sparse as sp
import SimPEG
from SimPEG import Utils
from SimPEG.EM.Utils import omega
from SimPEG.Utils import Zero, Identity
class Fields(SimPEG.Problem.Fields):
@@ -32,8 +33,7 @@ class Fields_e(Fields):
ePrimary = np.zeros_like(eSolution)
for i, src in enumerate(srcList):
ep = src.ePrimary(self.prob)
if ep is not None:
ePrimary[:,i] = ep
ePrimary[:,i] = ePrimary[:,i] + ep
return ePrimary
def _eSecondary(self, eSolution, srcList):
@@ -43,18 +43,17 @@ class Fields_e(Fields):
return self._ePrimary(eSolution,srcList) + self._eSecondary(eSolution,srcList)
def _eDeriv_u(self, src, v, adjoint = False):
return None
return Identity()*v
def _eDeriv_m(self, src, v, adjoint = False):
# assuming primary does not depend on the model
return None
return Zero()
def _bPrimary(self, eSolution, srcList):
bPrimary = np.zeros([self._edgeCurl.shape[0],eSolution.shape[1]],dtype = complex)
for i, src in enumerate(srcList):
bp = src.bPrimary(self.prob)
if bp is not None:
bPrimary[:,i] += bp
bPrimary[:,i] = bPrimary[:,i] + bp
return bPrimary
def _bSecondary(self, eSolution, srcList):
@@ -63,8 +62,7 @@ class Fields_e(Fields):
for i, src in enumerate(srcList):
b[:,i] *= - 1./(1j*omega(src.freq))
S_m, _ = src.eval(self.prob)
if S_m is not None:
b[:,i] += 1./(1j*omega(src.freq)) * S_m
b[:,i] = b[:,i]+ 1./(1j*omega(src.freq)) * S_m
return b
def _bSecondaryDeriv_u(self, src, v, adjoint = False):
@@ -76,9 +74,7 @@ class Fields_e(Fields):
def _bSecondaryDeriv_m(self, src, v, adjoint = False):
S_mDeriv, _ = src.evalDeriv(self.prob, adjoint)
S_mDeriv = S_mDeriv(v)
if S_mDeriv is not None:
return 1./(1j * omega(src.freq)) * S_mDeriv
return None
return 1./(1j * omega(src.freq)) * S_mDeriv
def _b(self, eSolution, srcList):
return self._bPrimary(eSolution, srcList) + self._bSecondary(eSolution, srcList)
@@ -118,8 +114,7 @@ class Fields_b(Fields):
bPrimary = np.zeros_like(bSolution)
for i, src in enumerate(srcList):
bp = src.bPrimary(self.prob)
if bp is not None:
bPrimary[:,i] = bp
bPrimary[:,i] = bPrimary[:,i] + bp
return bPrimary
def _bSecondary(self, bSolution, srcList):
@@ -129,26 +124,24 @@ class Fields_b(Fields):
return self._bPrimary(bSolution, srcList) + self._bSecondary(bSolution, srcList)
def _bDeriv_u(self, src, v, adjoint=False):
return None
return Identity()*v
def _bDeriv_m(self, src, v, adjoint=False):
# assuming primary does not depend on the model
return None
return Zero()
def _ePrimary(self, bSolution, srcList):
ePrimary = np.zeros([self._edgeCurl.shape[1],bSolution.shape[1]],dtype = complex)
for i,src in enumerate(srcList):
ep = src.ePrimary(self.prob)
if ep is not None:
ePrimary[:,i] = ep
ePrimary[:,i] = ePrimary[:,i] + ep
return ePrimary
def _eSecondary(self, bSolution, srcList):
e = self._MeSigmaI * ( self._edgeCurl.T * ( self._MfMui * bSolution))
for i,src in enumerate(srcList):
_,S_e = src.eval(self.prob)
if S_e is not None:
e[:,i] += -self._MeSigmaI * S_e
e[:,i] = e[:,i]+ -self._MeSigmaI * S_e
return e
def _eSecondaryDeriv_u(self, src, v, adjoint=False):
@@ -166,8 +159,7 @@ class Fields_b(Fields):
Me = Me.T
w = self._edgeCurl.T * (self._MfMui * bSolution)
if S_e is not None:
w += -Utils.mkvc(Me * S_e,2)
w = w - Utils.mkvc(Me * S_e,2)
if not adjoint:
de_dm = self._MeSigmaIDeriv(w) * v
@@ -177,8 +169,7 @@ class Fields_b(Fields):
_, S_eDeriv = src.evalDeriv(self.prob, adjoint)
Se_Deriv = S_eDeriv(v)
if Se_Deriv is not None:
de_dm += -self._MeSigmaI * Se_Deriv
de_dm = de_dm - self._MeSigmaI * Se_Deriv
return de_dm
@@ -219,8 +210,7 @@ class Fields_j(Fields):
jPrimary = np.zeros_like(jSolution,dtype = complex)
for i, src in enumerate(srcList):
jp = src.jPrimary(self.prob)
if jp is not None:
jPrimary[:,i] += jp
jPrimary[:,i] = jPrimary[:,i] + jp
return jPrimary
def _jSecondary(self, jSolution, srcList):
@@ -230,18 +220,17 @@ class Fields_j(Fields):
return self._jPrimary(jSolution, srcList) + self._jSecondary(jSolution, srcList)
def _jDeriv_u(self, src, v, adjoint=False):
return None
return Identity()*v
def _jDeriv_m(self, src, v, adjoint=False):
# assuming primary does not depend on the model
return None
return Zero()
def _hPrimary(self, jSolution, srcList):
hPrimary = np.zeros([self._edgeCurl.shape[1],jSolution.shape[1]],dtype = complex)
for i, src in enumerate(srcList):
hp = src.hPrimary(self.prob)
if hp is not None:
hPrimary[:,i] = hp
hPrimary[:,i] = hPrimary[:,i] + hp
return hPrimary
def _hSecondary(self, jSolution, srcList):
@@ -249,8 +238,7 @@ class Fields_j(Fields):
for i, src in enumerate(srcList):
h[:,i] *= -1./(1j*omega(src.freq))
S_m,_ = src.eval(self.prob)
if S_m is not None:
h[:,i] += 1./(1j*omega(src.freq)) * self._MeMuI * (S_m)
h[:,i] = h[:,i]+ 1./(1j*omega(src.freq)) * self._MeMuI * (S_m)
return h
def _hSecondaryDeriv_u(self, src, v, adjoint=False):
@@ -276,12 +264,10 @@ class Fields_j(Fields):
if not adjoint:
S_mDeriv = S_mDeriv(v)
if S_mDeriv is not None:
hDeriv_m += 1./(1j*omega(src.freq)) * MeMuI * (Me * S_mDeriv)
hDeriv_m = hDeriv_m + 1./(1j*omega(src.freq)) * MeMuI * (Me * S_mDeriv)
elif adjoint:
S_mDeriv = S_mDeriv(Me.T * (MeMuI.T * v))
if S_mDeriv is not None:
hDeriv_m += 1./(1j*omega(src.freq)) * S_mDeriv
hDeriv_m = hDeriv_m + 1./(1j*omega(src.freq)) * S_mDeriv
return hDeriv_m
@@ -320,9 +306,8 @@ class Fields_h(Fields):
hPrimary = np.zeros_like(hSolution,dtype = complex)
for i, src in enumerate(srcList):
hp = src.hPrimary(self.prob)
if hp is not None:
hPrimary[:,i] += hp
return hPrimary
hPrimary[:,i] = hPrimary[:,i] + hp
return hPrimary
def _hSecondary(self, hSolution, srcList):
return hSolution
@@ -331,26 +316,24 @@ class Fields_h(Fields):
return self._hPrimary(hSolution, srcList) + self._hSecondary(hSolution, srcList)
def _hDeriv_u(self, src, v, adjoint=False):
return None
return Identity()*v
def _hDeriv_m(self, src, v, adjoint=False):
# assuming primary does not depend on the model
return None
return Zero()
def _jPrimary(self, hSolution, srcList):
jPrimary = np.zeros([self._edgeCurl.shape[0], hSolution.shape[1]], dtype = complex)
for i, src in enumerate(srcList):
jp = src.jPrimary(self.prob)
if jp is not None:
jPrimary[:,i] = jp
jPrimary[:,i] = jPrimary[:,i] + jp
return jPrimary
def _jSecondary(self, hSolution, srcList):
j = self._edgeCurl*hSolution
for i, src in enumerate(srcList):
_,S_e = src.eval(self.prob)
if S_e is not None:
j[:,i] += -S_e
j[:,i] = j[:,i]+ -S_e
return j
def _jSecondaryDeriv_u(self, src, v, adjoint=False):
@@ -362,9 +345,7 @@ class Fields_h(Fields):
def _jSecondaryDeriv_m(self, src, v, adjoint=False):
_,S_eDeriv = src.evalDeriv(self.prob, adjoint)
S_eDeriv = S_eDeriv(v)
if S_eDeriv is not None:
return -S_eDeriv
return None
return -S_eDeriv
def _j(self, hSolution, srcList):
return self._jPrimary(hSolution, srcList) + self._jSecondary(hSolution, srcList)
+14 -45
View File
@@ -1,6 +1,7 @@
from SimPEG import Survey, Problem, Utils, np, sp
from scipy.constants import mu_0
from SimPEG.EM.Utils import *
from SimPEG.Utils import Zero
# from SurveyFDEM import Rx
@@ -18,28 +19,28 @@ class BaseSrc(Survey.BaseSrc):
return lambda v: self.S_mDeriv(prob,v,adjoint), lambda v: self.S_eDeriv(prob,v,adjoint)
def bPrimary(self, prob):
return None
return Zero()
def hPrimary(self, prob):
return None
return Zero()
def ePrimary(self, prob):
return None
return Zero()
def jPrimary(self, prob):
return None
return Zero()
def S_m(self, prob):
return None
return Zero()
def S_e(self, prob):
return None
return Zero()
def S_mDeriv(self, prob, v, adjoint = False):
return None
return Zero()
def S_eDeriv(self, prob, v, adjoint = False):
return None
return Zero()
class RawVec_e(BaseSrc):
@@ -51,30 +52,14 @@ class RawVec_e(BaseSrc):
:param rxList: receiver list
"""
def __init__(self, rxList, freq, S_e, ePrimary=None, bPrimary=None, hPrimary=None, jPrimary=None):
def __init__(self, rxList, freq, S_e): #, ePrimary=None, bPrimary=None, hPrimary=None, jPrimary=None):
self._S_e = np.array(S_e,dtype=complex)
self._ePrimary = ePrimary
self._bPrimary = bPrimary
self._hPrimary = hPrimary
self._jPrimary = jPrimary
self.freq = float(freq)
BaseSrc.__init__(self, rxList)
def S_e(self, prob):
return self._S_e
def ePrimary(self, prob):
return self._ePrimary
def bPrimary(self, prob):
return self._bPrimary
def hPrimary(self, prob):
return self._hPrimary
def jPrimary(self, prob):
return self._jPrimary
class RawVec_m(BaseSrc):
"""
@@ -85,32 +70,16 @@ class RawVec_m(BaseSrc):
:param rxList: receiver list
"""
def __init__(self, rxList, freq, S_m, integrate = True, ePrimary=None, bPrimary=None, hPrimary=None, jPrimary=None):
def __init__(self, rxList, freq, S_m, integrate = True): #ePrimary=Zero(), bPrimary=Zero(), hPrimary=Zero(), jPrimary=Zero()):
self._S_m = np.array(S_m,dtype=complex)
self.freq = float(freq)
self.integrate = integrate
self._ePrimary = np.array(ePrimary,dtype=complex)
self._bPrimary = np.array(bPrimary,dtype=complex)
self._hPrimary = np.array(hPrimary,dtype=complex)
self._jPrimary = np.array(jPrimary,dtype=complex)
BaseSrc.__init__(self, rxList)
def S_m(self, prob):
return self._S_m
def ePrimary(self, prob):
return self._ePrimary
def bPrimary(self, prob):
return self._bPrimary
def hPrimary(self, prob):
return self._hPrimary
def jPrimary(self, prob):
return self._jPrimary
class RawVec(BaseSrc):
"""
@@ -192,7 +161,7 @@ class MagDipole(BaseSrc):
def S_e(self, prob):
if all(np.r_[self.mu] == np.r_[prob.curModel.mu]):
return None
return Zero()
else:
eqLocs = prob._eqLocs
@@ -261,7 +230,7 @@ class MagDipole_Bfield(BaseSrc):
def S_e(self, prob):
if all(np.r_[self.mu] == np.r_[prob.curModel.mu]):
return None
return Zero()
else:
eqLocs = prob._eqLocs
@@ -329,7 +298,7 @@ class CircularLoop(BaseSrc):
def S_e(self, prob):
if all(np.r_[self.mu] == np.r_[prob.curModel.mu]):
return None
return Zero()
else:
eqLocs = prob._eqLocs
+2
View File
@@ -1,8 +1,10 @@
import SimPEG
from SimPEG.EM.Utils import *
from scipy.constants import mu_0
from SimPEG.Utils import Zero, Identity
import SrcFDEM as Src
####################################################
# Receivers
####################################################
+8
View File
@@ -37,12 +37,20 @@ def SolverWrapD(fun, factorize=True, checkAccuracy=True, accuracyTol=1e-6):
if len(b.shape) == 1 or b.shape[1] == 1:
b = b.flatten()
# Just one RHS
if b.dtype is np.dtype('O'):
b = b.astype(type(b[0]))
if factorize:
X = self.solver.solve(b, **self.kwargs)
else:
X = fun(self.A, b, **self.kwargs)
else: # Multiple RHSs
if b.dtype is np.dtype('O'):
b = b.astype(type(b[0,0]))
X = np.empty_like(b)
for i in range(b.shape[1]):
if factorize:
X[:,i] = self.solver.solve(b[:,i])
+1 -4
View File
@@ -3,10 +3,7 @@ import time
import numpy as np
from functools import wraps
class SimPEGMetaClass(type):
def __new__(cls, name, bases, attrs):
return super(SimPEGMetaClass, cls).__new__(cls, name, bases, attrs)
SimPEGMetaClass = type
def memProfileWrapper(towrap, *funNames):
"""
+2
View File
@@ -399,8 +399,10 @@ def diagEst(matFun, n, k=None, approach='Probing'):
class Zero(object):
def __add__(self, v):return v
def __radd__(self, v):return v
def __iadd__(self, v):return v
def __sub__(self, v):return -v
def __rsub__(self, v):return v
def __isub__(self, v):return v
def __mul__(self, v):return self
def __rmul__(self, v):return self
def __div__(self, v): return self
+7
View File
@@ -20,6 +20,13 @@ class Tests(unittest.TestCase):
assert 3*z == 0
assert z*3 == 0
assert z/3 == 0
a = 1
a += z
assert a == 1
a = 1
a += z
assert a == 1
self.assertRaises(ZeroDivisionError, lambda:3/z)
def test_mat_zero(self):