mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-28 00:30:01 +08:00
add a 2d to 3d map and test it!
This commit is contained in:
+58
-1
@@ -12,7 +12,8 @@ class IdentityMap(object):
|
||||
|
||||
mesh = None #: A SimPEG Mesh
|
||||
|
||||
def __init__(self, mesh):
|
||||
def __init__(self, mesh, **kwargs):
|
||||
Utils.setKwargs(self, **kwargs)
|
||||
self.mesh = mesh
|
||||
|
||||
@property
|
||||
@@ -261,6 +262,62 @@ class Vertical1DMap(IdentityMap):
|
||||
), shape=(repNum, 1))
|
||||
return sp.kron(sp.identity(self.nP), repVec)
|
||||
|
||||
|
||||
class Map2Dto3D(IdentityMap):
|
||||
"""Map2Dto3D
|
||||
|
||||
Given a 2D vector, this will extend to the full
|
||||
3D model space.
|
||||
"""
|
||||
|
||||
normal = 'Y' #: The normal
|
||||
|
||||
def __init__(self, mesh, **kwargs):
|
||||
assert mesh.dim == 3, 'Only works for a 3D Mesh'
|
||||
IdentityMap.__init__(self, mesh, **kwargs)
|
||||
assert self.normal in ['X','Y','Z'], 'For now, only "Y" normal is supported'
|
||||
|
||||
@property
|
||||
def nP(self):
|
||||
"""Number of model properties.
|
||||
|
||||
The number of cells in the
|
||||
last dimension of the mesh."""
|
||||
if self.normal == 'Z':
|
||||
return self.mesh.nCx * self.mesh.nCy
|
||||
elif self.normal == 'Y':
|
||||
return self.mesh.nCx * self.mesh.nCz
|
||||
elif self.normal == 'X':
|
||||
return self.mesh.nCy * self.mesh.nCz
|
||||
|
||||
def _transform(self, m):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
:rtype: numpy.array
|
||||
:return: transformed model
|
||||
"""
|
||||
m = Utils.mkvc(m)
|
||||
if self.normal == 'Z':
|
||||
return Utils.mkvc(m.reshape(self.mesh.vnC[[0,1]], order='F')[:,:,np.newaxis].repeat(self.mesh.nCz,axis=2))
|
||||
elif self.normal == 'Y':
|
||||
return Utils.mkvc(m.reshape(self.mesh.vnC[[0,2]], order='F')[:,np.newaxis,:].repeat(self.mesh.nCy,axis=1))
|
||||
elif self.normal == 'X':
|
||||
return Utils.mkvc(m.reshape(self.mesh.vnC[[1,2]], order='F')[np.newaxis,:,:].repeat(self.mesh.nCx,axis=0))
|
||||
|
||||
def deriv(self, m):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
:rtype: scipy.csr_matrix
|
||||
:return: derivative of transformed model
|
||||
"""
|
||||
inds = self * np.arange(self.nP)
|
||||
nC, nP = self.mesh.nC, self.nP
|
||||
P = sp.csr_matrix(
|
||||
(np.ones(nC),
|
||||
(range(nC), inds)
|
||||
), shape=(nC, nP))
|
||||
return P
|
||||
|
||||
class Mesh2Mesh(IdentityMap):
|
||||
"""
|
||||
Takes a model on one mesh are translates it to another mesh.
|
||||
|
||||
@@ -13,9 +13,10 @@ class MapTests(unittest.TestCase):
|
||||
a = np.array([1, 1, 1])
|
||||
b = np.array([1, 2])
|
||||
self.mesh2 = Mesh.TensorMesh([a, b], x0=np.array([3, 5]))
|
||||
self.mesh3 = Mesh.TensorMesh([a, b, [3,4]], x0=np.array([3, 5, 2]))
|
||||
self.mesh22 = Mesh.TensorMesh([b, a], x0=np.array([3, 5]))
|
||||
|
||||
def test_transforms(self):
|
||||
def test_transforms2D(self):
|
||||
for M in dir(Maps):
|
||||
try:
|
||||
maps = getattr(Maps, M)(self.mesh2)
|
||||
@@ -24,6 +25,15 @@ class MapTests(unittest.TestCase):
|
||||
continue
|
||||
self.assertTrue(maps.test())
|
||||
|
||||
def test_transforms3D(self):
|
||||
for M in dir(Maps):
|
||||
try:
|
||||
maps = getattr(Maps, M)(self.mesh3)
|
||||
assert isinstance(maps, Maps.IdentityMap)
|
||||
except Exception, e:
|
||||
continue
|
||||
self.assertTrue(maps.test())
|
||||
|
||||
def test_Mesh2MeshMap(self):
|
||||
maps = Maps.Mesh2Mesh([self.mesh22, self.mesh2])
|
||||
self.assertTrue(maps.test())
|
||||
@@ -90,5 +100,34 @@ class MapTests(unittest.TestCase):
|
||||
self.assertRaises(ValueError, lambda: actMap * vertMap * expMap )
|
||||
|
||||
|
||||
def test_map2Dto3D_x(self):
|
||||
M2 = Mesh.TensorMesh([2,4])
|
||||
M3 = Mesh.TensorMesh([3,2,4])
|
||||
m = np.random.rand(M2.nC)
|
||||
m2to3 = Maps.Map2Dto3D(M3, normal='X')
|
||||
m = np.arange(m2to3.nP)
|
||||
self.assertTrue(m2to3.test())
|
||||
self.assertTrue(np.all(Utils.mkvc( (m2to3 * m).reshape(M3.vnC,order='F')[0,:,:] ) == m))
|
||||
|
||||
|
||||
def test_map2Dto3D_y(self):
|
||||
M2 = Mesh.TensorMesh([3,4])
|
||||
M3 = Mesh.TensorMesh([3,2,4])
|
||||
m = np.random.rand(M2.nC)
|
||||
m2to3 = Maps.Map2Dto3D(M3, normal='Y')
|
||||
m = np.arange(m2to3.nP)
|
||||
self.assertTrue(m2to3.test())
|
||||
self.assertTrue(np.all(Utils.mkvc( (m2to3 * m).reshape(M3.vnC,order='F')[:,0,:] ) == m))
|
||||
|
||||
def test_map2Dto3D_z(self):
|
||||
M2 = Mesh.TensorMesh([3,2])
|
||||
M3 = Mesh.TensorMesh([3,2,4])
|
||||
m = np.random.rand(M2.nC)
|
||||
m2to3 = Maps.Map2Dto3D(M3, normal='Z')
|
||||
m = np.arange(m2to3.nP)
|
||||
self.assertTrue(m2to3.test())
|
||||
self.assertTrue(np.all(Utils.mkvc( (m2to3 * m).reshape(M3.vnC,order='F')[:,:,0] ) == m))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user