diff --git a/SimPEG/BaseMesh.py b/SimPEG/BaseMesh.py index a197b8a4..9d8982e6 100644 --- a/SimPEG/BaseMesh.py +++ b/SimPEG/BaseMesh.py @@ -1,4 +1,5 @@ import numpy as np +from utils import mkvc class BaseMesh(object): @@ -52,6 +53,113 @@ class BaseMesh(object): return locals() x0 = property(**x0()) + def r(self, x, xType='CC', outType='CC', format='V'): + """ + Mesh.r is a quick reshape command that will do the best it can at giving you what you want. + + For example, you have a face variable, and you want the x component of it reshaped to a 3D matrix. + + Mesh.r can fulfil your dreams... + + mesh.r(V, 'F', 'Fx', 'M') + | | | { How: 'M' or ['V'] for a matrix (ndgrid style) or a vector (n x dim) } + | | { What you want: ['CC'], 'N', 'F', 'Fx', 'Fy', 'Fz', 'E', 'Ex', 'Ey', or 'Ez' } + | { What is it: ['CC'], 'N', 'F', 'Fx', 'Fy', 'Fz', 'E', 'Ex', 'Ey', or 'Ez' } + { The input: as a list or ndarray } + + + For example: + + Xex, Yex, Zex = r(mesh.gridEx, 'Ex', 'Ex', 'M') # Separates each component of the Ex grid into 3 matrices + + XedgeVector = r(edgeVector, 'E', 'Ex', 'V') # Given an edge vector, this will return just the part on the x edges as a vector + + eX, eY, eZ = r(edgeVector, 'E', 'E', 'V') # Separates each component of the edgeVector into 3 vectors + """ + + assert (type(x) == list or type(x) == np.ndarray), "x must be either a list or a ndarray" + assert xType in ['CC', 'N', 'F', 'Fx', 'Fy', 'Fz', 'E', 'Ex', 'Ey', 'Ez'], "xType must be either 'CC', 'N', 'F', 'Fx', 'Fy', 'Fz', 'E', 'Ex', 'Ey', or 'Ez'" + assert outType in ['CC', 'N', 'F', 'Fx', 'Fy', 'Fz', 'E', 'Ex', 'Ey', 'Ez'], "outType must be either 'CC', 'N', 'F', Fx', 'Fy', 'Fz', 'E', 'Ex', 'Ey', or 'Ez'" + assert format in ['M', 'V'], "format must be either 'M' or 'V'" + assert outType[:len(xType)] == xType, "You cannot change types when reshaping." + assert xType in outType, 'You cannot change type of components.' + if type(x) == list: + for i, xi in enumerate(x): + assert type(x) == np.ndarray, "x[%i] must be a numpy array" % i + assert xi.size == x[0].size, "Number of elements in list must not change." + + x_array = np.ones((x.size, len(x))) + # Unwrap it and put it in a np array + for i, xi in enumerate(x): + x_array[:, i] = mkvc(xi) + x = x_array + + assert type(x) == np.ndarray, "x must be a numpy array" + + x = x[:] # make a copy. + xTypeIsFExyz = len(xType) > 1 and xType[0] in ['F', 'E'] and xType[1] in ['x', 'y', 'z'] + + def outKernal(xx, nn): + """Returns xx as either a matrix (shape == nn) or a vector.""" + if format == 'M': + return xx.reshape(nn, order='F') + elif format == 'V': + return mkvc(xx) + + def switchKernal(xx): + """Switches over the different options.""" + if xType in ['CC', 'N']: + nn = (self.n) if xType == 'CC' else (self.n+1) + assert xx.size == np.prod(nn), "Number of elements must not change." + return outKernal(xx, nn) + elif xType in ['F', 'E']: + # This will only deal with components of fields, not full 'F' or 'E' + xx = mkvc(xx) # unwrap it in case it is a matrix + nn = self.nF if xType == 'F' else self.nE + nn = np.r_[0, nn] + + nx = [0, 0, 0] + nx[0] = self.nFx if xType == 'F' else self.nEx + nx[1] = self.nFy if xType == 'F' else self.nEy + nx[2] = self.nFz if xType == 'F' else self.nEz + + for dim, dimName in enumerate(['x', 'y', 'z']): + if dimName in outType: + assert self.dim > dim, ("Dimensions of mesh not great enough for %s%s", (xType, dimName)) + assert xx.size == np.sum(nn), 'Vector is not the right size.' + start = np.sum(nn[:dim+1]) + end = np.sum(nn[:dim+2]) + return outKernal(xx[start:end], nx[dim]) + elif xTypeIsFExyz: + # This will deal with partial components (x, y or z) lying on edges or faces + if 'x' in xType: + nn = self.nFx if 'F' in xType else self.nEx + elif 'y' in xType: + nn = self.nFy if 'F' in xType else self.nEy + elif 'z' in xType: + nn = self.nFz if 'F' in xType else self.nEz + assert xx.size == np.prod(nn), 'Vector is not the right size.' + return outKernal(xx, nn) + + # Check if we are dealing with a vector quantity + isVectorQuantity = len(x.shape) == 2 and x.shape[1] == self.dim + + if outType in ['F', 'E']: + assert ~isVectorQuantity, 'Not sure what to do with a vector vector quantity..' + outTypeCopy = outType + out = () + for ii, dirName in enumerate(['x', 'y', 'z'][:self.dim]): + outType = outTypeCopy + dirName + out += (switchKernal(x),) + return out + elif isVectorQuantity: + out = () + for ii in range(x.shape[1]): + out += (switchKernal(x[:, ii]),) + return out + else: + return switchKernal(x) + def n(): doc = "Number of Cells in each dimension (array of integers)" fget = lambda self: self._n diff --git a/SimPEG/tests/test_basemesh.py b/SimPEG/tests/test_basemesh.py index 9aece19b..60373011 100644 --- a/SimPEG/tests/test_basemesh.py +++ b/SimPEG/tests/test_basemesh.py @@ -44,6 +44,74 @@ class TestBaseMesh(unittest.TestCase): self.assertTrue(np.all([c, f, e])) + def test_mesh_r_E_V(self): + ex = np.ones(self.mesh.nE[0]) + ey = np.ones(self.mesh.nE[1])*2 + ez = np.ones(self.mesh.nE[2])*3 + e = np.r_[ex, ey, ez] + tex = self.mesh.r(e, 'E', 'Ex', 'V') + tey = self.mesh.r(e, 'E', 'Ey', 'V') + tez = self.mesh.r(e, 'E', 'Ez', 'V') + self.assertTrue(np.all(tex == ex)) + self.assertTrue(np.all(tey == ey)) + self.assertTrue(np.all(tez == ez)) + tex, tey, tez = self.mesh.r(e, 'E', 'E', 'V') + self.assertTrue(np.all(tex == ex)) + self.assertTrue(np.all(tey == ey)) + self.assertTrue(np.all(tez == ez)) + + def test_mesh_r_F_V(self): + fx = np.ones(self.mesh.nF[0]) + fy = np.ones(self.mesh.nF[1])*2 + fz = np.ones(self.mesh.nF[2])*3 + f = np.r_[fx, fy, fz] + tfx = self.mesh.r(f, 'F', 'Fx', 'V') + tfy = self.mesh.r(f, 'F', 'Fy', 'V') + tfz = self.mesh.r(f, 'F', 'Fz', 'V') + self.assertTrue(np.all(tfx == fx)) + self.assertTrue(np.all(tfy == fy)) + self.assertTrue(np.all(tfz == fz)) + tfx, tfy, tfz = self.mesh.r(f, 'F', 'F', 'V') + self.assertTrue(np.all(tfx == fx)) + self.assertTrue(np.all(tfy == fy)) + self.assertTrue(np.all(tfz == fz)) + + def test_mesh_r_E_M(self): + g = np.ones((np.prod(self.mesh.nEx), 3)) + g[:, 1] = 2 + g[:, 2] = 3 + Xex, Yex, Zex = self.mesh.r(g, 'Ex', 'Ex', 'M') + self.assertTrue(np.all(Xex.shape == self.mesh.nEx)) + self.assertTrue(np.all(Yex.shape == self.mesh.nEx)) + self.assertTrue(np.all(Zex.shape == self.mesh.nEx)) + self.assertTrue(np.all(Xex == 1)) + self.assertTrue(np.all(Yex == 2)) + self.assertTrue(np.all(Zex == 3)) + + def test_mesh_r_F_M(self): + g = np.ones((np.prod(self.mesh.nFx), 3)) + g[:, 1] = 2 + g[:, 2] = 3 + Xfx, Yfx, Zfx = self.mesh.r(g, 'Fx', 'Fx', 'M') + self.assertTrue(np.all(Xfx.shape == self.mesh.nFx)) + self.assertTrue(np.all(Yfx.shape == self.mesh.nFx)) + self.assertTrue(np.all(Zfx.shape == self.mesh.nFx)) + self.assertTrue(np.all(Xfx == 1)) + self.assertTrue(np.all(Yfx == 2)) + self.assertTrue(np.all(Zfx == 3)) + + def test_mesh_r_CC_M(self): + g = np.ones((self.mesh.nC, 3)) + g[:, 1] = 2 + g[:, 2] = 3 + Xc, Yc, Zc = self.mesh.r(g, 'CC', 'CC', 'M') + self.assertTrue(np.all(Xc.shape == self.mesh.n)) + self.assertTrue(np.all(Yc.shape == self.mesh.n)) + self.assertTrue(np.all(Zc.shape == self.mesh.n)) + self.assertTrue(np.all(Xc == 1)) + self.assertTrue(np.all(Yc == 2)) + self.assertTrue(np.all(Zc == 3)) + class TestMeshNumbers2D(unittest.TestCase): @@ -84,5 +152,58 @@ class TestMeshNumbers2D(unittest.TestCase): self.assertTrue(np.all([c, f, e])) + def test_mesh_r_E_V(self): + ex = np.ones(self.mesh.nE[0]) + ey = np.ones(self.mesh.nE[1])*2 + e = np.r_[ex, ey] + tex = self.mesh.r(e, 'E', 'Ex', 'V') + tey = self.mesh.r(e, 'E', 'Ey', 'V') + self.assertTrue(np.all(tex == ex)) + self.assertTrue(np.all(tey == ey)) + tex, tey = self.mesh.r(e, 'E', 'E', 'V') + self.assertTrue(np.all(tex == ex)) + self.assertTrue(np.all(tey == ey)) + self.assertRaises(AssertionError, self.mesh.r, e, 'E', 'Ez', 'V') + + def test_mesh_r_F_V(self): + fx = np.ones(self.mesh.nF[0]) + fy = np.ones(self.mesh.nF[1])*2 + f = np.r_[fx, fy] + tfx = self.mesh.r(f, 'F', 'Fx', 'V') + tfy = self.mesh.r(f, 'F', 'Fy', 'V') + self.assertTrue(np.all(tfx == fx)) + self.assertTrue(np.all(tfy == fy)) + tfx, tfy = self.mesh.r(f, 'F', 'F', 'V') + self.assertTrue(np.all(tfx == fx)) + self.assertTrue(np.all(tfy == fy)) + self.assertRaises(AssertionError, self.mesh.r, f, 'F', 'Fz', 'V') + + def test_mesh_r_E_M(self): + g = np.ones((np.prod(self.mesh.nEx), 2)) + g[:, 1] = 2 + Xex, Yex = self.mesh.r(g, 'Ex', 'Ex', 'M') + self.assertTrue(np.all(Xex.shape == self.mesh.nEx)) + self.assertTrue(np.all(Yex.shape == self.mesh.nEx)) + self.assertTrue(np.all(Xex == 1)) + self.assertTrue(np.all(Yex == 2)) + + def test_mesh_r_F_M(self): + g = np.ones((np.prod(self.mesh.nFx), 2)) + g[:, 1] = 2 + Xfx, Yfx = self.mesh.r(g, 'Fx', 'Fx', 'M') + self.assertTrue(np.all(Xfx.shape == self.mesh.nFx)) + self.assertTrue(np.all(Yfx.shape == self.mesh.nFx)) + self.assertTrue(np.all(Xfx == 1)) + self.assertTrue(np.all(Yfx == 2)) + + def test_mesh_r_CC_M(self): + g = np.ones((self.mesh.nC, 2)) + g[:, 1] = 2 + Xc, Yc = self.mesh.r(g, 'CC', 'CC', 'M') + self.assertTrue(np.all(Xc.shape == self.mesh.n)) + self.assertTrue(np.all(Yc.shape == self.mesh.n)) + self.assertTrue(np.all(Xc == 1)) + self.assertTrue(np.all(Yc == 2)) + if __name__ == '__main__': unittest.main()