mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-01 23:16:25 +08:00
Working for mumps
This commit is contained in:
+3
-40
@@ -368,47 +368,10 @@ class ComboMap(IdentityMap):
|
||||
mi = map_i.transform(mi)
|
||||
return deriv
|
||||
|
||||
class ComplexMap(IdentityMap):
|
||||
"""docstring for ComplexMap
|
||||
|
||||
default nP is nC in the mesh times 2 [real, imag]
|
||||
|
||||
"""
|
||||
def __init__(self, mesh, nP=None):
|
||||
IdentityMap.__init__(self, mesh)
|
||||
if nP is not None:
|
||||
assert nP%2 == 0, 'nP must be even.'
|
||||
self._nP = nP or (self.mesh.nC * 2)
|
||||
|
||||
@property
|
||||
def nP(self):
|
||||
return self._nP
|
||||
|
||||
def transform(self, m):
|
||||
nC = self.mesh.nC
|
||||
return m[:nC] + m[nC:]*1j
|
||||
|
||||
def transformDeriv(self, m):
|
||||
nC = self.nP/2
|
||||
shp = (nC, nC*2)
|
||||
def fwd(v):
|
||||
return v[:nC] + v[nC:]*1j
|
||||
def adj(v):
|
||||
return np.r_[v.real,v.imag]
|
||||
return Utils.SimPEGLinearOperator(shp,fwd,adj)
|
||||
|
||||
transformInverse = transformDeriv
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from SimPEG import *
|
||||
# mesh = Mesh.TensorMesh([10,8])
|
||||
# combo = ComboMap(mesh, [ExpMap, Vertical1DMap])
|
||||
# m = combo.example()
|
||||
# print m.shape
|
||||
# print combo.test(np.arange(8))
|
||||
mesh = Mesh.TensorMesh([10,8])
|
||||
mapping = ComplexMap(mesh)
|
||||
m = mapping.example()
|
||||
combo = ComboMap(mesh, [ExpMap, Vertical1DMap])
|
||||
m = combo.example()
|
||||
print m.shape
|
||||
print mapping.test(m)
|
||||
print combo.test(np.arange(8))
|
||||
|
||||
+3
-24
@@ -57,7 +57,6 @@ class BaseTimeRx(BaseRx):
|
||||
"""SimPEG Receiver Object"""
|
||||
|
||||
times = None #: Times when the receivers were active.
|
||||
projTLoc = 'N'
|
||||
|
||||
def __init__(self, locs, times, rxType, **kwargs):
|
||||
self.times = times
|
||||
@@ -68,26 +67,6 @@ class BaseTimeRx(BaseRx):
|
||||
"""Number of data in the receiver."""
|
||||
return self.locs.shape[0] * len(self.times)
|
||||
|
||||
def getSpatialP(self, mesh):
|
||||
"""
|
||||
Returns the spatial projection matrix.
|
||||
|
||||
.. note::
|
||||
|
||||
This is not stored in memory, but is created on demand.
|
||||
"""
|
||||
return mesh.getInterpolationMat(self.locs, self.projGLoc)
|
||||
|
||||
def getTimeP(self, timeMesh):
|
||||
"""
|
||||
Returns the time projection matrix.
|
||||
|
||||
.. note::
|
||||
|
||||
This is not stored in memory, but is created on demand.
|
||||
"""
|
||||
return timeMesh.getInterpolationMat(self.times, self.projTLoc)
|
||||
|
||||
def getP(self, mesh, timeMesh):
|
||||
"""
|
||||
Returns the projection matrices as a
|
||||
@@ -96,13 +75,13 @@ class BaseTimeRx(BaseRx):
|
||||
|
||||
.. note::
|
||||
|
||||
Projection matrices are stored as a dictionary (mesh, timeMesh) if storeProjections is True
|
||||
Projection matrices are stored as a dictionary (mesh, timeMesh)
|
||||
"""
|
||||
if (mesh, timeMesh) in self._Ps:
|
||||
return self._Ps[(mesh, timeMesh)]
|
||||
|
||||
Ps = self.getSpatialP(mesh)
|
||||
Pt = self.getTimeP(timeMesh)
|
||||
Ps = mesh.getInterpolationMat(self.locs, self.projGLoc)
|
||||
Pt = timeMesh.getInterpolationMat(self.times, 'N')
|
||||
P = sp.kron(Pt, Ps)
|
||||
|
||||
if self.storeProjections:
|
||||
|
||||
@@ -2,9 +2,10 @@ import numpy as np
|
||||
from matutils import mkvc
|
||||
import warnings
|
||||
|
||||
def DSolverWrap(fun, factorize=True, checkAccuracy=True, accuracyTol=1e-6):
|
||||
def DSolverWrap(fun, factorize=True, destroy = False, checkAccuracy=True, accuracyTol=1e-6):
|
||||
|
||||
def __init__(self, A, **kwargs):
|
||||
|
||||
self.A = A.tocsc()
|
||||
self.kwargs = kwargs
|
||||
if factorize:
|
||||
@@ -34,7 +35,13 @@ def DSolverWrap(fun, factorize=True, checkAccuracy=True, accuracyTol=1e-6):
|
||||
warnings.warn(msg, RuntimeWarning)
|
||||
return X
|
||||
|
||||
return type(fun.__name__, (object,), {"__init__": __init__, "solve": solve})
|
||||
def clean(self):
|
||||
if destroy == True:
|
||||
return self.solver.clean()
|
||||
else:
|
||||
return True
|
||||
|
||||
return type(fun.__name__, (object,), {"__init__": __init__, "solve": solve, "clean": clean})
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -330,14 +330,3 @@ def invPropertyTensor(M, tensor, returnMatrix=False):
|
||||
return makePropertyTensor(M, T)
|
||||
|
||||
return T
|
||||
|
||||
|
||||
|
||||
from scipy.sparse.linalg import LinearOperator
|
||||
|
||||
class SimPEGLinearOperator(LinearOperator):
|
||||
"""Extends scipy.sparse.linalg.LinearOperator to have a .T function."""
|
||||
@property
|
||||
def T(self):
|
||||
return self.__class__((self.shape[1],self.shape[0]),self.rmatvec,rmatvec=self.matvec,matmat=self.matmat)
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ from scipy import sparse as sp
|
||||
from matutils import mkvc, ndgrid, sub2ind, sdiag
|
||||
from codeutils import asArray_N_x_Dim
|
||||
from codeutils import isScalar
|
||||
import SimPEG
|
||||
|
||||
|
||||
def exampleLrmGrid(nC, exType):
|
||||
assert type(nC) == list, "nC must be a list containing the number of nodes"
|
||||
@@ -131,8 +133,7 @@ def readUBCTensorMesh(fileName):
|
||||
y0 = mesh[1][1]
|
||||
z0 = -(hz.sum()-mesh[1][2])
|
||||
|
||||
from SimPEG import Mesh
|
||||
mesh3D = Mesh.TensorMesh([hx, hy, hz], np.r_[x0, y0, z0])
|
||||
mesh3D = SimPEG.Mesh.TensorMesh([hx, hy, hz], np.r_[x0, y0, z0])
|
||||
|
||||
return mesh3D
|
||||
|
||||
@@ -147,7 +148,7 @@ def readUBCTensorModel(fileName, mesh):
|
||||
model = np.reshape(model, (mesh.nCz, mesh.nCx, mesh.nCy), order = 'F')
|
||||
model = model[::-1,:,:]
|
||||
model = np.transpose(model, (1, 2, 0))
|
||||
model = mkvc(model)
|
||||
model = SimPEG.Utils.mkvc(model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user