diff --git a/SimPEG/mesh/QuadTreeMesh.py b/SimPEG/Mesh/TreeMesh.py similarity index 62% rename from SimPEG/mesh/QuadTreeMesh.py rename to SimPEG/Mesh/TreeMesh.py index 1ac0f5eb..34519225 100644 --- a/SimPEG/mesh/QuadTreeMesh.py +++ b/SimPEG/Mesh/TreeMesh.py @@ -1,4 +1,4 @@ -from SimPEG import np, sp, utils +from SimPEG import np, sp, Utils, Solver import matplotlib.pyplot as plt import matplotlib.colors as colors import matplotlib.cm as cmx @@ -35,14 +35,16 @@ class TreeFace(object): self.mesh = mesh self.children = None self.numFace = None + self.x0 = np.array(x0, dtype=float) self.faceType = faceType - self.sz = sz + self.sz = np.array(sz, dtype=float) self.dim = dim self.depth = depth mesh.faces.add(self) - if faceType is 'x': self.mesh.faceX.add(self) - elif faceType is 'y': self.mesh.faceY.add(self) + if faceType is 'x': self.mesh.facesX.add(self) + elif faceType is 'y': self.mesh.facesY.add(self) + elif faceType is 'z': self.mesh.facesZ.add(self) self.tangent = np.zeros(dim) self.tangent[1 if faceType is 'x' else 0] = 1 self.normal = np.zeros(dim) @@ -53,12 +55,15 @@ class TreeFace(object): @property def index(self): + if not self.mesh.isNumbered: raise Exception('Mesh is not numbered.') if self.isleaf: return np.r_[self.numFace] return np.concatenate([face.index for face in self.children]) - @property - def area(self): return self.sz + @property + def area(self): + """area of the face""" + return self.sz.prod() def refine(self): if not self.isleaf: return @@ -68,32 +73,30 @@ class TreeFace(object): # Create refined x0's x0r_0 = self.x0 x0r_1 = self.x0+0.5*self.tangent*self.sz - self.children[0] = TreeFace(self.mesh, x0=x0r_0, faceType=self.faceType, dim=self.dim, sz=0.5*self.sz, depth=self.depth+1,parent=self) - self.children[1] = TreeFace(self.mesh, x0=x0r_1, faceType=self.faceType, dim=self.dim, sz=0.5*self.sz, depth=self.depth+1,parent=self) + self.children[0] = TreeFace(self.mesh, x0=x0r_0, faceType=self.faceType, dim=self.dim, sz=0.5*self.sz, depth=self.depth+1, parent=self) + self.children[1] = TreeFace(self.mesh, x0=x0r_1, faceType=self.faceType, dim=self.dim, sz=0.5*self.sz, depth=self.depth+1, parent=self) self.mesh.faces.remove(self) if self.faceType is 'x': - self.mesh.faceX.remove(self) + self.mesh.facesX.remove(self) elif self.faceType is 'y': - self.mesh.faceY.remove(self) - + self.mesh.facesY.remove(self) def viz(self, ax, text=True): if not self.isleaf: return ax.plot(np.r_[self.x0[0],self.x0[0]+self.tangent[0]*self.sz], np.r_[self.x0[1], self.x0[1]+self.tangent[1]*self.sz],'r-') if text: ax.text(self.x0[0]+0.5*self.tangent[0]*self.sz, self.x0[1]+0.5*self.tangent[1]*self.sz,self.numFace) + @property + def center(self): + return self.x0 + 0.5*self.tangent*self.sz class TreeNode(object): """docstring for TreeNode""" - def __init__(self, mesh, x0=[0,0], dim=2, depth=0, sz=[1,1], parent=None, fXm=None, fXp=None, fYm=None, fYp=None): + children = None #: + numCell = None - fXm = fXm if fXm is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1] ], faceType='x', dim=dim, sz=sz[1], depth=depth, parent=parent) - fXp = fXp if fXp is not None else TreeFace(mesh, x0=np.r_[x0[0]+sz[0], x0[1] ], faceType='x', dim=dim, sz=sz[1], depth=depth, parent=parent) - fYm = fYm if fYm is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1] ], faceType='y', dim=dim, sz=sz[0], depth=depth, parent=parent) - fYp = fYp if fYp is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1]+sz[1]], faceType='y', dim=dim, sz=sz[0], depth=depth, parent=parent) - - self.faces = {"fXm":fXm, "fXp":fXp, "fYm":fYm, "fYp":fYp} + def __init__(self, mesh, x0=[0,0], dim=2, depth=0, sz=[1,1], parent=None, fXm=None, fXp=None, fYm=None, fYp=None, fZm=None, fZp=None): self.mesh = mesh self.x0 = np.array(x0, dtype=float) @@ -101,8 +104,22 @@ class TreeNode(object): self.dim = dim self.depth = depth self.parent = parent - self.children = None - self.numCell = None + if dim == 2: + fXm = fXm if fXm is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1] ], faceType='x', dim=dim, sz=np.r_[sz[1]], depth=depth, parent=parent) + fXp = fXp if fXp is not None else TreeFace(mesh, x0=np.r_[x0[0]+sz[0], x0[1] ], faceType='x', dim=dim, sz=np.r_[sz[1]], depth=depth, parent=parent) + fYm = fYm if fYm is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1] ], faceType='y', dim=dim, sz=np.r_[sz[0]], depth=depth, parent=parent) + fYp = fYp if fYp is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1]+sz[1]], faceType='y', dim=dim, sz=np.r_[sz[0]], depth=depth, parent=parent) + self.faces = {"fXm":fXm, "fXp":fXp, "fYm":fYm, "fYp":fYp} + + elif dim == 3: + fXm = fXm if fXm is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1] , x0[2] ], faceType='x', dim=dim, sz=np.r_[sz[1], sz[2]], depth=depth, parent=parent) + fXp = fXp if fXp is not None else TreeFace(mesh, x0=np.r_[x0[0]+sz[0], x0[1] , x0[2] ], faceType='x', dim=dim, sz=np.r_[sz[1], sz[2]], depth=depth, parent=parent) + fYm = fYm if fYm is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1] , x0[2] ], faceType='y', dim=dim, sz=np.r_[sz[0], sz[2]], depth=depth, parent=parent) + fYp = fYp if fYp is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1]+sz[1], x0[2] ], faceType='y', dim=dim, sz=np.r_[sz[0], sz[2]], depth=depth, parent=parent) + fZm = fZm if fZm is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1] , x0[2] ], faceType='z', dim=dim, sz=np.r_[sz[0], sz[1]], depth=depth, parent=parent) + fZp = fZp if fZp is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1] , x0[2]+sz[2]], faceType='z', dim=dim, sz=np.r_[sz[0], sz[1]], depth=depth, parent=parent) + self.faces = {"fXm":fXm, "fXp":fXp, "fYm":fYm, "fYp":fYp, "fZm":fZm, "fZp":fZp} + mesh.cells.add(self) @property @@ -120,13 +137,17 @@ class TreeNode(object): return np.max([node.branchdepth for node in self.children.flatten('F')]) @property - def gridCC(self): return self.x0 + 0.5*self.sz + def center(self): return self.x0 + 0.5*self.sz def refine(self, function=None): + if self.dim == 2: + return self._refine2D(function=function) + + def _refine2D(self, function=None): if not self.isleaf and function is None: return if function is not None: - do = function(self.gridCC) > self.depth + do = function(self.center) > self.depth if not do: return self.mesh.isNumbered = False @@ -187,20 +208,32 @@ class TreeNode(object): if not self.isleaf: return x0, sz = self.x0, self.sz ax.add_patch(plt.Rectangle((x0[0], x0[1]), sz[0], sz[1], facecolor=color, edgecolor='k')) - if text: ax.text(self.gridCC[0],self.gridCC[1],self.numCell) + if text: ax.text(self.center[0],self.center[1],self.numCell) -class QuadTreeMesh(object): - """docstring for QuadTreeMesh""" - def __init__(self, h, x0=[0,0]): - self.faces = set() - self.faceX = set() - self.faceY = set() - self.cells = set() - self.x0 = np.array(x0,dtype=float) +class TreeMesh(object): + """TreeMesh""" + def __init__(self, h, x0=None): + + assert type(h) is list, 'h must be a list' + self.h = h + if x0 is None: + x0 = np.zeros(self.dim) + else: + assert type(x0) in [list, np.ndarray], 'x0 must be a numpy array or a list' + assert len(x0) == self.dim, 'x0 must have the same dimensions as the mesh' + self.x0 = np.array(x0, dtype=float) + + # set the sets for holding the faces and cells + self.cells = set() + self.faces = set() + self.facesX = set() + self.facesY = set() + if self.dim == 3: self.facesZ = set() + self.children = np.empty([hi.size for hi in h],dtype=TreeNode) for i in range(h[0].size): for j in range(h[1].size): @@ -210,7 +243,7 @@ class QuadTreeMesh(object): x0j = (np.r_[x0[1], h[1][:j]]).sum() self.children[i][j] = TreeNode(self, x0=[x0i, x0j], dim=len(h), depth=0, sz=[h[0][i], h[1][j]], fXm=fXm, fYm=fYm) - isNumbered = utils.dependentProperty('_isNumbered', False, ['_faceDiv'], 'Setting this to False will delete all operators.') + isNumbered = Utils.dependentProperty('_isNumbered', False, ['_faceDiv'], 'Setting this to False will delete all operators.') @property def branchdepth(self): @@ -223,41 +256,81 @@ class QuadTreeMesh(object): def number(self): if self.isNumbered: return - sortedCells = sorted(M.cells,key=SortByX0()) - sortedFaceX = sorted(M.faceX,key=SortByX0()) - sortedFaceY = sorted(M.faceY,key=SortByX0()) - nFx = len(sortedFaceX) - for i, sc in enumerate(sortedCells): sc.numCell = i - for i, sfx in enumerate(sortedFaceX): sfx.numFace = i - for i, sfy in enumerate(sortedFaceY): sfy.numFace = i + nFx + self.sortedCells = sorted(self.cells,key=SortByX0()) + for i, sc in enumerate(self.sortedCells): sc.numCell = i - self.sortedCells = sortedCells - self.sortedFaceX = sortedFaceX - self.sortedFaceY = sortedFaceY + self.sortedFaceX = sorted(self.facesX,key=SortByX0()) + for i, sfx in enumerate(self.sortedFaceX): sfx.numFace = i + + self.sortedFaceY = sorted(self.facesY,key=SortByX0()) + for i, sfy in enumerate(self.sortedFaceY): sfy.numFace = i + self.nFx + + if self.dim == 3: + self.sortedFaceZ = sorted(self.facesZ,key=SortByX0()) + for i, sfz in enumerate(self.sortedFaceZ): sfz.numFace = i + self.nFx + self.nFy self.isNumbered = True + @property + def dim(self): return len(self.h) - @property - def dim(self): return len(self.x0) @property def nC(self): return len(self.cells) + @property def nF(self): return len(self.faces) + @property def nFx(self): return len(self.facesX) + @property def nFy(self): return len(self.facesY) + @property + def nFz(self): return len(self.facesZ) + + @property + def nE(self): return len(self.faces) + + @property + def nEx(self): + if self.dim == 2: + return len(self.facesY) + else: raise NotImplementedError('nEx') + + @property + def nEy(self): + if self.dim == 2: + return len(self.facesX) + else: raise NotImplementedError('nEy') + @property def gridCC(self): if getattr(self, '_gridCC', None) is None: self.number() self._gridCC = np.empty((self.nC,self.dim)) for ii, cell in enumerate(self.sortedCells): - self._gridCC[ii,:] = cell.gridCC + self._gridCC[ii,:] = cell.center return self._gridCC + @property + def gridFx(self): + if getattr(self, '_gridFx', None) is None: + self.number() + self._gridFx = np.empty((self.nFx,self.dim)) + for ii, face in enumerate(self.sortedFaceX): + self._gridFx[ii,:] = face.center + return self._gridFx + + @property + def gridFy(self): + if getattr(self, '_gridFy', None) is None: + self.number() + self._gridFy = np.empty((self.nFy,self.dim)) + for ii, face in enumerate(self.sortedFaceY): + self._gridFy[ii,:] = face.center + return self._gridFy + @property def vol(self): self.number() @@ -280,20 +353,27 @@ class QuadTreeMesh(object): VOL = self.vol D = sp.csr_matrix((V,(I,J)), shape=(M.nC, M.nF)) S = self.area - self._faceDiv = utils.sdiag(1/VOL)*D*utils.sdiag(S) + self._faceDiv = Utils.sdiag(1/VOL)*D*Utils.sdiag(S) return self._faceDiv - def plotGrid(self, ax=None, text=True, plotC=True, plotF=False): + def plotGrid(self, ax=None, text=True, plotC=True, plotF=False, showIt=False): if ax is None: ax = plt.subplot(111) if plotC: [node.viz(ax, text=text) for node in self.cells] if plotF: [node.viz(ax, text=text) for node in self.faces] ax.set_xlim((self.x0[0], self.h[0].sum())) ax.set_ylim((self.x0[1], self.h[1].sum())) + if showIt: plt.show() - def plotImage(self, I, ax=None): + def plotImage(self, I, ax=None, showIt=True): + if self.dim == 2: + self._plotImage2D(I, ax=ax, showIt=showIt) + elif self.dim == 3: + raise NotImplementedError('3D visualization is not yet implemented.') + + def _plotImage2D(self, I, ax=None, showIt=True): if ax is None: ax = plt.subplot(111) jet = cm = plt.get_cmap('jet') cNorm = colors.Normalize(vmin=I.min(), vmax=I.max()) @@ -304,11 +384,12 @@ class QuadTreeMesh(object): node.viz(ax=ax, color=scalarMap.to_rgba(I[ii])) scalarMap._A = [] # http://stackoverflow.com/questions/8342549/matplotlib-add-colorbar-to-a-sequence-of-line-plots plt.colorbar(scalarMap) + if showIt: plt.show() if __name__ == '__main__': - M = QuadTreeMesh([np.ones(x) for x in [4,10]]) + M = TreeMesh([np.ones(x) for x in [4,10]]) def function(xc): r = xc - np.r_[2.,6.] @@ -330,7 +411,7 @@ if __name__ == '__main__': q = np.zeros(M.nC) q[208] = -1.0 q[291] = 1.0 - b = utils.Solver(-DIV*DIV.T).solve(q) + b = Solver(-DIV*DIV.T).solve(q) plt.figure() M.plotImage(b) # plt.gca().invert_yaxis() diff --git a/SimPEG/Mesh/__init__.py b/SimPEG/Mesh/__init__.py index 66ee88c8..3da22a01 100644 --- a/SimPEG/Mesh/__init__.py +++ b/SimPEG/Mesh/__init__.py @@ -1,6 +1,6 @@ from Cyl1DMesh import Cyl1DMesh from TensorMesh import TensorMesh -from QuadTreeMesh import QuadTreeMesh +from TreeMesh import TreeMesh from LogicallyOrthogonalMesh import LogicallyOrthogonalMesh from BaseMesh import BaseMesh from TensorView import TensorView diff --git a/SimPEG/tests/test_QuadTree.py b/SimPEG/tests/test_QuadTree.py index c085c7ae..2fef3fdd 100644 --- a/SimPEG/tests/test_QuadTree.py +++ b/SimPEG/tests/test_QuadTree.py @@ -1,19 +1,41 @@ -from SimPEG import mesh, np +from SimPEG import Mesh, np import unittest -class TestCheckDerivative(unittest.TestCase): +class TestQuadTreeMesh(unittest.TestCase): def setUp(self): - M = mesh.QuadTreeMesh([3,2],[1,2]) + M = Mesh.TreeMesh([np.ones(x) for x in [3,2]]) for ii in range(1): M.children[ii,ii].refine() self.M = M + M.number() + # M.plotGrid(showIt=True) def test_MeshSizes(self): - self.assertTrue(len(self.M.faces)==25) - self.assertTrue(len(self.M.cells)==9) + self.assertTrue(self.M.nC==9) + self.assertTrue(self.M.nF==25) + self.assertTrue(self.M.nFx==12) + self.assertTrue(self.M.nFy==13) + self.assertTrue(self.M.nE==25) + self.assertTrue(self.M.nEx==13) + self.assertTrue(self.M.nEy==12) + + def test_gridCC(self): + x = np.r_[0.25,0.75,1.5,2.5,0.25,0.75,0.5,1.5,2.5] + y = np.r_[0.25,0.25,0.5,0.5,0.75,0.75,1.5,1.5,1.5] + self.assertTrue(np.linalg.norm((np.c_[x,y]-self.M.gridCC).flatten()) == 0) + + def test_gridFx(self): + x = np.r_[0.0,0.5,1.0,2.0,3.0,0.0,0.5,1.0,0.0,1.0,2.0,3.0] + y = np.r_[0.25,0.25,0.25,0.5,0.5,0.75,0.75,0.75,1.5,1.5,1.5,1.5] + self.assertTrue(np.linalg.norm((np.c_[x,y]-self.M.gridFx).flatten()) == 0) + + def test_gridFy(self): + x = np.r_[0.25,0.75,1.5,2.5,0.25,0.75,0.25,0.75,1.5,2.5,0.5,1.5,2.5] + y = np.r_[0,0,0,0,0.5,0.5,1,1,1,1,2,2,2] + self.assertTrue(np.linalg.norm((np.c_[x,y]-self.M.gridFy).flatten()) == 0)