diff --git a/SimPEG/Maps.py b/SimPEG/Maps.py index c5f1c1a5..72386695 100644 --- a/SimPEG/Maps.py +++ b/SimPEG/Maps.py @@ -369,29 +369,35 @@ class ComboMap(IdentityMap): return deriv class ComplexMap(IdentityMap): - """docstring for ComplexMap""" - def __init__(self, mesh): + """docstring for ComplexMap + + default nP is nC in the mesh times 2 [real, imag] + + """ + def __init__(self, mesh, nP=None): IdentityMap.__init__(self, mesh) + if nP is not None: + assert nP%2 == 0, 'nP must be even.' + self._nP = nP or (self.mesh.nC * 2) @property def nP(self): - return self.mesh.nC * 2 + return self._nP def transform(self, m): nC = self.mesh.nC return m[:nC] + m[nC:]*1j - def transformDeriv(self, m, v=None, adjoint=False): - nC = self.mesh.nC - if v is None and adjoint is False: - return sp.hstack((sp.identity(nC), np.identity(nC,dtype=complex)*1j)) - - assert v is not None, 'Must have a vector to multiply by.' - - if adjoint is False: - return sp.hstack((sp.identity(nC), np.identity(nC,dtype=complex)*1j)) * v - elif adjoint is True: + def transformDeriv(self, m): + nC = self.nP/2 + shp = (nC, nC*2) + def fwd(v): + return v[:nC] + v[nC:]*1j + def adj(v): return np.r_[v.real,v.imag] + return Utils.SimPEGLinearOperator(shp,fwd,adj) + + transformInverse = transformDeriv if __name__ == '__main__': diff --git a/SimPEG/Utils/matutils.py b/SimPEG/Utils/matutils.py index 2d3b87f7..9fedfce8 100644 --- a/SimPEG/Utils/matutils.py +++ b/SimPEG/Utils/matutils.py @@ -330,3 +330,14 @@ def invPropertyTensor(M, tensor, returnMatrix=False): return makePropertyTensor(M, T) return T + + + +from scipy.sparse.linalg import LinearOperator + +class SimPEGLinearOperator(LinearOperator): + """Extends scipy.sparse.linalg.LinearOperator to have a .T function.""" + @property + def T(self): + return self.__class__((self.shape[1],self.shape[0]),self.rmatvec,rmatvec=self.matvec,matmat=self.matmat) +