diff --git a/SimPEG/Maps.py b/SimPEG/Maps.py index b6af6f13..706e83f2 100644 --- a/SimPEG/Maps.py +++ b/SimPEG/Maps.py @@ -12,7 +12,8 @@ class IdentityMap(object): mesh = None #: A SimPEG Mesh - def __init__(self, mesh): + def __init__(self, mesh, **kwargs): + Utils.setKwargs(self, **kwargs) self.mesh = mesh @property @@ -261,6 +262,62 @@ class Vertical1DMap(IdentityMap): ), shape=(repNum, 1)) return sp.kron(sp.identity(self.nP), repVec) + +class Map2Dto3D(IdentityMap): + """Map2Dto3D + + Given a 2D vector, this will extend to the full + 3D model space. + """ + + normal = 'Y' #: The normal + + def __init__(self, mesh, **kwargs): + assert mesh.dim == 3, 'Only works for a 3D Mesh' + IdentityMap.__init__(self, mesh, **kwargs) + assert self.normal in ['X','Y','Z'], 'For now, only "Y" normal is supported' + + @property + def nP(self): + """Number of model properties. + + The number of cells in the + last dimension of the mesh.""" + if self.normal == 'Z': + return self.mesh.nCx * self.mesh.nCy + elif self.normal == 'Y': + return self.mesh.nCx * self.mesh.nCz + elif self.normal == 'X': + return self.mesh.nCy * self.mesh.nCz + + def _transform(self, m): + """ + :param numpy.array m: model + :rtype: numpy.array + :return: transformed model + """ + m = Utils.mkvc(m) + if self.normal == 'Z': + return Utils.mkvc(m.reshape(self.mesh.vnC[[0,1]], order='F')[:,:,np.newaxis].repeat(self.mesh.nCz,axis=2)) + elif self.normal == 'Y': + return Utils.mkvc(m.reshape(self.mesh.vnC[[0,2]], order='F')[:,np.newaxis,:].repeat(self.mesh.nCy,axis=1)) + elif self.normal == 'X': + return Utils.mkvc(m.reshape(self.mesh.vnC[[1,2]], order='F')[np.newaxis,:,:].repeat(self.mesh.nCx,axis=0)) + + def deriv(self, m): + """ + :param numpy.array m: model + :rtype: scipy.csr_matrix + :return: derivative of transformed model + """ + inds = self * np.arange(self.nP) + nC, nP = self.mesh.nC, self.nP + P = sp.csr_matrix( + (np.ones(nC), + (range(nC), inds) + ), shape=(nC, nP)) + return P + class Mesh2Mesh(IdentityMap): """ Takes a model on one mesh are translates it to another mesh. diff --git a/SimPEG/Tests/test_maps.py b/SimPEG/Tests/test_maps.py index 986fa500..6e595732 100644 --- a/SimPEG/Tests/test_maps.py +++ b/SimPEG/Tests/test_maps.py @@ -13,9 +13,10 @@ class MapTests(unittest.TestCase): a = np.array([1, 1, 1]) b = np.array([1, 2]) self.mesh2 = Mesh.TensorMesh([a, b], x0=np.array([3, 5])) + self.mesh3 = Mesh.TensorMesh([a, b, [3,4]], x0=np.array([3, 5, 2])) self.mesh22 = Mesh.TensorMesh([b, a], x0=np.array([3, 5])) - def test_transforms(self): + def test_transforms2D(self): for M in dir(Maps): try: maps = getattr(Maps, M)(self.mesh2) @@ -24,6 +25,15 @@ class MapTests(unittest.TestCase): continue self.assertTrue(maps.test()) + def test_transforms3D(self): + for M in dir(Maps): + try: + maps = getattr(Maps, M)(self.mesh3) + assert isinstance(maps, Maps.IdentityMap) + except Exception, e: + continue + self.assertTrue(maps.test()) + def test_Mesh2MeshMap(self): maps = Maps.Mesh2Mesh([self.mesh22, self.mesh2]) self.assertTrue(maps.test()) @@ -90,5 +100,34 @@ class MapTests(unittest.TestCase): self.assertRaises(ValueError, lambda: actMap * vertMap * expMap ) + def test_map2Dto3D_x(self): + M2 = Mesh.TensorMesh([2,4]) + M3 = Mesh.TensorMesh([3,2,4]) + m = np.random.rand(M2.nC) + m2to3 = Maps.Map2Dto3D(M3, normal='X') + m = np.arange(m2to3.nP) + self.assertTrue(m2to3.test()) + self.assertTrue(np.all(Utils.mkvc( (m2to3 * m).reshape(M3.vnC,order='F')[0,:,:] ) == m)) + + + def test_map2Dto3D_y(self): + M2 = Mesh.TensorMesh([3,4]) + M3 = Mesh.TensorMesh([3,2,4]) + m = np.random.rand(M2.nC) + m2to3 = Maps.Map2Dto3D(M3, normal='Y') + m = np.arange(m2to3.nP) + self.assertTrue(m2to3.test()) + self.assertTrue(np.all(Utils.mkvc( (m2to3 * m).reshape(M3.vnC,order='F')[:,0,:] ) == m)) + + def test_map2Dto3D_z(self): + M2 = Mesh.TensorMesh([3,2]) + M3 = Mesh.TensorMesh([3,2,4]) + m = np.random.rand(M2.nC) + m2to3 = Maps.Map2Dto3D(M3, normal='Z') + m = np.arange(m2to3.nP) + self.assertTrue(m2to3.test()) + self.assertTrue(np.all(Utils.mkvc( (m2to3 * m).reshape(M3.vnC,order='F')[:,:,0] ) == m)) + + if __name__ == '__main__': unittest.main()