mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-04 07:18:35 +08:00
Grid updates for TreeMesh, initial generalizations for OcTree
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from SimPEG import np, sp, utils
|
||||
from SimPEG import np, sp, Utils, Solver
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.colors as colors
|
||||
import matplotlib.cm as cmx
|
||||
@@ -35,14 +35,16 @@ class TreeFace(object):
|
||||
self.mesh = mesh
|
||||
self.children = None
|
||||
self.numFace = None
|
||||
|
||||
self.x0 = np.array(x0, dtype=float)
|
||||
self.faceType = faceType
|
||||
self.sz = sz
|
||||
self.sz = np.array(sz, dtype=float)
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
mesh.faces.add(self)
|
||||
if faceType is 'x': self.mesh.faceX.add(self)
|
||||
elif faceType is 'y': self.mesh.faceY.add(self)
|
||||
if faceType is 'x': self.mesh.facesX.add(self)
|
||||
elif faceType is 'y': self.mesh.facesY.add(self)
|
||||
elif faceType is 'z': self.mesh.facesZ.add(self)
|
||||
self.tangent = np.zeros(dim)
|
||||
self.tangent[1 if faceType is 'x' else 0] = 1
|
||||
self.normal = np.zeros(dim)
|
||||
@@ -53,12 +55,15 @@ class TreeFace(object):
|
||||
|
||||
@property
|
||||
def index(self):
|
||||
if not self.mesh.isNumbered: raise Exception('Mesh is not numbered.')
|
||||
if self.isleaf: return np.r_[self.numFace]
|
||||
return np.concatenate([face.index for face in self.children])
|
||||
|
||||
@property
|
||||
def area(self): return self.sz
|
||||
|
||||
@property
|
||||
def area(self):
|
||||
"""area of the face"""
|
||||
return self.sz.prod()
|
||||
|
||||
def refine(self):
|
||||
if not self.isleaf: return
|
||||
@@ -68,32 +73,30 @@ class TreeFace(object):
|
||||
# Create refined x0's
|
||||
x0r_0 = self.x0
|
||||
x0r_1 = self.x0+0.5*self.tangent*self.sz
|
||||
self.children[0] = TreeFace(self.mesh, x0=x0r_0, faceType=self.faceType, dim=self.dim, sz=0.5*self.sz, depth=self.depth+1,parent=self)
|
||||
self.children[1] = TreeFace(self.mesh, x0=x0r_1, faceType=self.faceType, dim=self.dim, sz=0.5*self.sz, depth=self.depth+1,parent=self)
|
||||
self.children[0] = TreeFace(self.mesh, x0=x0r_0, faceType=self.faceType, dim=self.dim, sz=0.5*self.sz, depth=self.depth+1, parent=self)
|
||||
self.children[1] = TreeFace(self.mesh, x0=x0r_1, faceType=self.faceType, dim=self.dim, sz=0.5*self.sz, depth=self.depth+1, parent=self)
|
||||
self.mesh.faces.remove(self)
|
||||
if self.faceType is 'x':
|
||||
self.mesh.faceX.remove(self)
|
||||
self.mesh.facesX.remove(self)
|
||||
elif self.faceType is 'y':
|
||||
self.mesh.faceY.remove(self)
|
||||
|
||||
self.mesh.facesY.remove(self)
|
||||
|
||||
def viz(self, ax, text=True):
|
||||
if not self.isleaf: return
|
||||
ax.plot(np.r_[self.x0[0],self.x0[0]+self.tangent[0]*self.sz], np.r_[self.x0[1], self.x0[1]+self.tangent[1]*self.sz],'r-')
|
||||
if text: ax.text(self.x0[0]+0.5*self.tangent[0]*self.sz, self.x0[1]+0.5*self.tangent[1]*self.sz,self.numFace)
|
||||
|
||||
@property
|
||||
def center(self):
|
||||
return self.x0 + 0.5*self.tangent*self.sz
|
||||
|
||||
|
||||
class TreeNode(object):
|
||||
"""docstring for TreeNode"""
|
||||
def __init__(self, mesh, x0=[0,0], dim=2, depth=0, sz=[1,1], parent=None, fXm=None, fXp=None, fYm=None, fYp=None):
|
||||
children = None #:
|
||||
numCell = None
|
||||
|
||||
fXm = fXm if fXm is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1] ], faceType='x', dim=dim, sz=sz[1], depth=depth, parent=parent)
|
||||
fXp = fXp if fXp is not None else TreeFace(mesh, x0=np.r_[x0[0]+sz[0], x0[1] ], faceType='x', dim=dim, sz=sz[1], depth=depth, parent=parent)
|
||||
fYm = fYm if fYm is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1] ], faceType='y', dim=dim, sz=sz[0], depth=depth, parent=parent)
|
||||
fYp = fYp if fYp is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1]+sz[1]], faceType='y', dim=dim, sz=sz[0], depth=depth, parent=parent)
|
||||
|
||||
self.faces = {"fXm":fXm, "fXp":fXp, "fYm":fYm, "fYp":fYp}
|
||||
def __init__(self, mesh, x0=[0,0], dim=2, depth=0, sz=[1,1], parent=None, fXm=None, fXp=None, fYm=None, fYp=None, fZm=None, fZp=None):
|
||||
|
||||
self.mesh = mesh
|
||||
self.x0 = np.array(x0, dtype=float)
|
||||
@@ -101,8 +104,22 @@ class TreeNode(object):
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
self.parent = parent
|
||||
self.children = None
|
||||
self.numCell = None
|
||||
if dim == 2:
|
||||
fXm = fXm if fXm is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1] ], faceType='x', dim=dim, sz=np.r_[sz[1]], depth=depth, parent=parent)
|
||||
fXp = fXp if fXp is not None else TreeFace(mesh, x0=np.r_[x0[0]+sz[0], x0[1] ], faceType='x', dim=dim, sz=np.r_[sz[1]], depth=depth, parent=parent)
|
||||
fYm = fYm if fYm is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1] ], faceType='y', dim=dim, sz=np.r_[sz[0]], depth=depth, parent=parent)
|
||||
fYp = fYp if fYp is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1]+sz[1]], faceType='y', dim=dim, sz=np.r_[sz[0]], depth=depth, parent=parent)
|
||||
self.faces = {"fXm":fXm, "fXp":fXp, "fYm":fYm, "fYp":fYp}
|
||||
|
||||
elif dim == 3:
|
||||
fXm = fXm if fXm is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1] , x0[2] ], faceType='x', dim=dim, sz=np.r_[sz[1], sz[2]], depth=depth, parent=parent)
|
||||
fXp = fXp if fXp is not None else TreeFace(mesh, x0=np.r_[x0[0]+sz[0], x0[1] , x0[2] ], faceType='x', dim=dim, sz=np.r_[sz[1], sz[2]], depth=depth, parent=parent)
|
||||
fYm = fYm if fYm is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1] , x0[2] ], faceType='y', dim=dim, sz=np.r_[sz[0], sz[2]], depth=depth, parent=parent)
|
||||
fYp = fYp if fYp is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1]+sz[1], x0[2] ], faceType='y', dim=dim, sz=np.r_[sz[0], sz[2]], depth=depth, parent=parent)
|
||||
fZm = fZm if fZm is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1] , x0[2] ], faceType='z', dim=dim, sz=np.r_[sz[0], sz[1]], depth=depth, parent=parent)
|
||||
fZp = fZp if fZp is not None else TreeFace(mesh, x0=np.r_[x0[0] , x0[1] , x0[2]+sz[2]], faceType='z', dim=dim, sz=np.r_[sz[0], sz[1]], depth=depth, parent=parent)
|
||||
self.faces = {"fXm":fXm, "fXp":fXp, "fYm":fYm, "fYp":fYp, "fZm":fZm, "fZp":fZp}
|
||||
|
||||
mesh.cells.add(self)
|
||||
|
||||
@property
|
||||
@@ -120,13 +137,17 @@ class TreeNode(object):
|
||||
return np.max([node.branchdepth for node in self.children.flatten('F')])
|
||||
|
||||
@property
|
||||
def gridCC(self): return self.x0 + 0.5*self.sz
|
||||
def center(self): return self.x0 + 0.5*self.sz
|
||||
|
||||
def refine(self, function=None):
|
||||
if self.dim == 2:
|
||||
return self._refine2D(function=function)
|
||||
|
||||
def _refine2D(self, function=None):
|
||||
if not self.isleaf and function is None: return
|
||||
|
||||
if function is not None:
|
||||
do = function(self.gridCC) > self.depth
|
||||
do = function(self.center) > self.depth
|
||||
if not do: return
|
||||
|
||||
self.mesh.isNumbered = False
|
||||
@@ -187,20 +208,32 @@ class TreeNode(object):
|
||||
if not self.isleaf: return
|
||||
x0, sz = self.x0, self.sz
|
||||
ax.add_patch(plt.Rectangle((x0[0], x0[1]), sz[0], sz[1], facecolor=color, edgecolor='k'))
|
||||
if text: ax.text(self.gridCC[0],self.gridCC[1],self.numCell)
|
||||
if text: ax.text(self.center[0],self.center[1],self.numCell)
|
||||
|
||||
|
||||
|
||||
|
||||
class QuadTreeMesh(object):
|
||||
"""docstring for QuadTreeMesh"""
|
||||
def __init__(self, h, x0=[0,0]):
|
||||
self.faces = set()
|
||||
self.faceX = set()
|
||||
self.faceY = set()
|
||||
self.cells = set()
|
||||
self.x0 = np.array(x0,dtype=float)
|
||||
class TreeMesh(object):
|
||||
"""TreeMesh"""
|
||||
def __init__(self, h, x0=None):
|
||||
|
||||
assert type(h) is list, 'h must be a list'
|
||||
|
||||
self.h = h
|
||||
if x0 is None:
|
||||
x0 = np.zeros(self.dim)
|
||||
else:
|
||||
assert type(x0) in [list, np.ndarray], 'x0 must be a numpy array or a list'
|
||||
assert len(x0) == self.dim, 'x0 must have the same dimensions as the mesh'
|
||||
self.x0 = np.array(x0, dtype=float)
|
||||
|
||||
# set the sets for holding the faces and cells
|
||||
self.cells = set()
|
||||
self.faces = set()
|
||||
self.facesX = set()
|
||||
self.facesY = set()
|
||||
if self.dim == 3: self.facesZ = set()
|
||||
|
||||
self.children = np.empty([hi.size for hi in h],dtype=TreeNode)
|
||||
for i in range(h[0].size):
|
||||
for j in range(h[1].size):
|
||||
@@ -210,7 +243,7 @@ class QuadTreeMesh(object):
|
||||
x0j = (np.r_[x0[1], h[1][:j]]).sum()
|
||||
self.children[i][j] = TreeNode(self, x0=[x0i, x0j], dim=len(h), depth=0, sz=[h[0][i], h[1][j]], fXm=fXm, fYm=fYm)
|
||||
|
||||
isNumbered = utils.dependentProperty('_isNumbered', False, ['_faceDiv'], 'Setting this to False will delete all operators.')
|
||||
isNumbered = Utils.dependentProperty('_isNumbered', False, ['_faceDiv'], 'Setting this to False will delete all operators.')
|
||||
|
||||
@property
|
||||
def branchdepth(self):
|
||||
@@ -223,41 +256,81 @@ class QuadTreeMesh(object):
|
||||
def number(self):
|
||||
if self.isNumbered: return
|
||||
|
||||
sortedCells = sorted(M.cells,key=SortByX0())
|
||||
sortedFaceX = sorted(M.faceX,key=SortByX0())
|
||||
sortedFaceY = sorted(M.faceY,key=SortByX0())
|
||||
nFx = len(sortedFaceX)
|
||||
for i, sc in enumerate(sortedCells): sc.numCell = i
|
||||
for i, sfx in enumerate(sortedFaceX): sfx.numFace = i
|
||||
for i, sfy in enumerate(sortedFaceY): sfy.numFace = i + nFx
|
||||
self.sortedCells = sorted(self.cells,key=SortByX0())
|
||||
for i, sc in enumerate(self.sortedCells): sc.numCell = i
|
||||
|
||||
self.sortedCells = sortedCells
|
||||
self.sortedFaceX = sortedFaceX
|
||||
self.sortedFaceY = sortedFaceY
|
||||
self.sortedFaceX = sorted(self.facesX,key=SortByX0())
|
||||
for i, sfx in enumerate(self.sortedFaceX): sfx.numFace = i
|
||||
|
||||
self.sortedFaceY = sorted(self.facesY,key=SortByX0())
|
||||
for i, sfy in enumerate(self.sortedFaceY): sfy.numFace = i + self.nFx
|
||||
|
||||
if self.dim == 3:
|
||||
self.sortedFaceZ = sorted(self.facesZ,key=SortByX0())
|
||||
for i, sfz in enumerate(self.sortedFaceZ): sfz.numFace = i + self.nFx + self.nFy
|
||||
|
||||
self.isNumbered = True
|
||||
|
||||
@property
|
||||
def dim(self): return len(self.h)
|
||||
|
||||
@property
|
||||
def dim(self): return len(self.x0)
|
||||
@property
|
||||
def nC(self): return len(self.cells)
|
||||
|
||||
@property
|
||||
def nF(self): return len(self.faces)
|
||||
|
||||
@property
|
||||
def nFx(self): return len(self.facesX)
|
||||
|
||||
@property
|
||||
def nFy(self): return len(self.facesY)
|
||||
|
||||
@property
|
||||
def nFz(self): return len(self.facesZ)
|
||||
|
||||
@property
|
||||
def nE(self): return len(self.faces)
|
||||
|
||||
@property
|
||||
def nEx(self):
|
||||
if self.dim == 2:
|
||||
return len(self.facesY)
|
||||
else: raise NotImplementedError('nEx')
|
||||
|
||||
@property
|
||||
def nEy(self):
|
||||
if self.dim == 2:
|
||||
return len(self.facesX)
|
||||
else: raise NotImplementedError('nEy')
|
||||
|
||||
@property
|
||||
def gridCC(self):
|
||||
if getattr(self, '_gridCC', None) is None:
|
||||
self.number()
|
||||
self._gridCC = np.empty((self.nC,self.dim))
|
||||
for ii, cell in enumerate(self.sortedCells):
|
||||
self._gridCC[ii,:] = cell.gridCC
|
||||
self._gridCC[ii,:] = cell.center
|
||||
return self._gridCC
|
||||
|
||||
@property
|
||||
def gridFx(self):
|
||||
if getattr(self, '_gridFx', None) is None:
|
||||
self.number()
|
||||
self._gridFx = np.empty((self.nFx,self.dim))
|
||||
for ii, face in enumerate(self.sortedFaceX):
|
||||
self._gridFx[ii,:] = face.center
|
||||
return self._gridFx
|
||||
|
||||
@property
|
||||
def gridFy(self):
|
||||
if getattr(self, '_gridFy', None) is None:
|
||||
self.number()
|
||||
self._gridFy = np.empty((self.nFy,self.dim))
|
||||
for ii, face in enumerate(self.sortedFaceY):
|
||||
self._gridFy[ii,:] = face.center
|
||||
return self._gridFy
|
||||
|
||||
@property
|
||||
def vol(self):
|
||||
self.number()
|
||||
@@ -280,20 +353,27 @@ class QuadTreeMesh(object):
|
||||
VOL = self.vol
|
||||
D = sp.csr_matrix((V,(I,J)), shape=(M.nC, M.nF))
|
||||
S = self.area
|
||||
self._faceDiv = utils.sdiag(1/VOL)*D*utils.sdiag(S)
|
||||
self._faceDiv = Utils.sdiag(1/VOL)*D*Utils.sdiag(S)
|
||||
return self._faceDiv
|
||||
|
||||
|
||||
def plotGrid(self, ax=None, text=True, plotC=True, plotF=False):
|
||||
def plotGrid(self, ax=None, text=True, plotC=True, plotF=False, showIt=False):
|
||||
if ax is None: ax = plt.subplot(111)
|
||||
|
||||
if plotC: [node.viz(ax, text=text) for node in self.cells]
|
||||
if plotF: [node.viz(ax, text=text) for node in self.faces]
|
||||
ax.set_xlim((self.x0[0], self.h[0].sum()))
|
||||
ax.set_ylim((self.x0[1], self.h[1].sum()))
|
||||
if showIt: plt.show()
|
||||
|
||||
|
||||
def plotImage(self, I, ax=None):
|
||||
def plotImage(self, I, ax=None, showIt=True):
|
||||
if self.dim == 2:
|
||||
self._plotImage2D(I, ax=ax, showIt=showIt)
|
||||
elif self.dim == 3:
|
||||
raise NotImplementedError('3D visualization is not yet implemented.')
|
||||
|
||||
def _plotImage2D(self, I, ax=None, showIt=True):
|
||||
if ax is None: ax = plt.subplot(111)
|
||||
jet = cm = plt.get_cmap('jet')
|
||||
cNorm = colors.Normalize(vmin=I.min(), vmax=I.max())
|
||||
@@ -304,11 +384,12 @@ class QuadTreeMesh(object):
|
||||
node.viz(ax=ax, color=scalarMap.to_rgba(I[ii]))
|
||||
scalarMap._A = [] # http://stackoverflow.com/questions/8342549/matplotlib-add-colorbar-to-a-sequence-of-line-plots
|
||||
plt.colorbar(scalarMap)
|
||||
if showIt: plt.show()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
M = QuadTreeMesh([np.ones(x) for x in [4,10]])
|
||||
M = TreeMesh([np.ones(x) for x in [4,10]])
|
||||
|
||||
def function(xc):
|
||||
r = xc - np.r_[2.,6.]
|
||||
@@ -330,7 +411,7 @@ if __name__ == '__main__':
|
||||
q = np.zeros(M.nC)
|
||||
q[208] = -1.0
|
||||
q[291] = 1.0
|
||||
b = utils.Solver(-DIV*DIV.T).solve(q)
|
||||
b = Solver(-DIV*DIV.T).solve(q)
|
||||
plt.figure()
|
||||
M.plotImage(b)
|
||||
# plt.gca().invert_yaxis()
|
||||
@@ -1,6 +1,6 @@
|
||||
from Cyl1DMesh import Cyl1DMesh
|
||||
from TensorMesh import TensorMesh
|
||||
from QuadTreeMesh import QuadTreeMesh
|
||||
from TreeMesh import TreeMesh
|
||||
from LogicallyOrthogonalMesh import LogicallyOrthogonalMesh
|
||||
from BaseMesh import BaseMesh
|
||||
from TensorView import TensorView
|
||||
|
||||
@@ -1,19 +1,41 @@
|
||||
from SimPEG import mesh, np
|
||||
from SimPEG import Mesh, np
|
||||
import unittest
|
||||
|
||||
|
||||
|
||||
class TestCheckDerivative(unittest.TestCase):
|
||||
class TestQuadTreeMesh(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
M = mesh.QuadTreeMesh([3,2],[1,2])
|
||||
M = Mesh.TreeMesh([np.ones(x) for x in [3,2]])
|
||||
for ii in range(1):
|
||||
M.children[ii,ii].refine()
|
||||
self.M = M
|
||||
M.number()
|
||||
# M.plotGrid(showIt=True)
|
||||
|
||||
def test_MeshSizes(self):
|
||||
self.assertTrue(len(self.M.faces)==25)
|
||||
self.assertTrue(len(self.M.cells)==9)
|
||||
self.assertTrue(self.M.nC==9)
|
||||
self.assertTrue(self.M.nF==25)
|
||||
self.assertTrue(self.M.nFx==12)
|
||||
self.assertTrue(self.M.nFy==13)
|
||||
self.assertTrue(self.M.nE==25)
|
||||
self.assertTrue(self.M.nEx==13)
|
||||
self.assertTrue(self.M.nEy==12)
|
||||
|
||||
def test_gridCC(self):
|
||||
x = np.r_[0.25,0.75,1.5,2.5,0.25,0.75,0.5,1.5,2.5]
|
||||
y = np.r_[0.25,0.25,0.5,0.5,0.75,0.75,1.5,1.5,1.5]
|
||||
self.assertTrue(np.linalg.norm((np.c_[x,y]-self.M.gridCC).flatten()) == 0)
|
||||
|
||||
def test_gridFx(self):
|
||||
x = np.r_[0.0,0.5,1.0,2.0,3.0,0.0,0.5,1.0,0.0,1.0,2.0,3.0]
|
||||
y = np.r_[0.25,0.25,0.25,0.5,0.5,0.75,0.75,0.75,1.5,1.5,1.5,1.5]
|
||||
self.assertTrue(np.linalg.norm((np.c_[x,y]-self.M.gridFx).flatten()) == 0)
|
||||
|
||||
def test_gridFy(self):
|
||||
x = np.r_[0.25,0.75,1.5,2.5,0.25,0.75,0.25,0.75,1.5,2.5,0.5,1.5,2.5]
|
||||
y = np.r_[0,0,0,0,0.5,0.5,1,1,1,1,2,2,2]
|
||||
self.assertTrue(np.linalg.norm((np.c_[x,y]-self.M.gridFy).flatten()) == 0)
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user