From c1b5f45ac79e523bd7fcb8d00f43b98717b1151f Mon Sep 17 00:00:00 2001 From: Rowan Cockett Date: Wed, 18 Nov 2015 11:52:05 -0800 Subject: [PATCH] PlotSlice in OcTree --- SimPEG/Mesh/TensorMesh.py | 40 ++++----- SimPEG/Mesh/TreeMesh.py | 181 +++++++++++++++++++++++++++++++++----- SimPEG/Mesh/View.py | 1 + 3 files changed, 181 insertions(+), 41 deletions(-) diff --git a/SimPEG/Mesh/TensorMesh.py b/SimPEG/Mesh/TensorMesh.py index 37bfea86..6b0e6e65 100644 --- a/SimPEG/Mesh/TensorMesh.py +++ b/SimPEG/Mesh/TensorMesh.py @@ -413,34 +413,34 @@ class TensorMesh(BaseTensorMesh, TensorView, DiffOperators, InnerProducts): break if n == 1: - outStr = outStr + ' {0:.2f},'.format(h) + outStr += ' {0:.2f},'.format(h) else: - outStr = outStr + ' {0:d}*{1:.2f},'.format(n,h) + outStr += ' {0:d}*{1:.2f},'.format(n,h) return outStr[:-1] if self.dim == 1: - outStr = outStr + '\n x0: {0:.2f}'.format(self.x0[0]) - outStr = outStr + '\n nCx: {0:d}'.format(self.nCx) - outStr = outStr + printH(self.hx, outStr='\n hx:') + outStr += '\n x0: {0:.2f}'.format(self.x0[0]) + outStr += '\n nCx: {0:d}'.format(self.nCx) + outStr += printH(self.hx, outStr='\n hx:') pass elif self.dim == 2: - outStr = outStr + '\n x0: {0:.2f}'.format(self.x0[0]) - outStr = outStr + '\n y0: {0:.2f}'.format(self.x0[1]) - outStr = outStr + '\n nCx: {0:d}'.format(self.nCx) - outStr = outStr + '\n nCy: {0:d}'.format(self.nCy) - outStr = outStr + printH(self.hx, outStr='\n hx:') - outStr = outStr + printH(self.hy, outStr='\n hy:') + outStr += '\n x0: {0:.2f}'.format(self.x0[0]) + outStr += '\n y0: {0:.2f}'.format(self.x0[1]) + outStr += '\n nCx: {0:d}'.format(self.nCx) + outStr += '\n nCy: {0:d}'.format(self.nCy) + outStr += printH(self.hx, outStr='\n hx:') + outStr += printH(self.hy, outStr='\n hy:') elif self.dim == 3: - outStr = outStr + '\n x0: {0:.2f}'.format(self.x0[0]) - outStr = outStr + '\n y0: {0:.2f}'.format(self.x0[1]) - outStr = outStr + '\n z0: {0:.2f}'.format(self.x0[2]) - outStr = outStr + '\n nCx: {0:d}'.format(self.nCx) - outStr = outStr + '\n nCy: {0:d}'.format(self.nCy) - outStr = outStr + '\n nCz: {0:d}'.format(self.nCz) - outStr = outStr + printH(self.hx, outStr='\n hx:') - outStr = outStr + printH(self.hy, outStr='\n hy:') - outStr = outStr + printH(self.hz, outStr='\n hz:') + outStr += '\n x0: {0:.2f}'.format(self.x0[0]) + outStr += '\n y0: {0:.2f}'.format(self.x0[1]) + outStr += '\n z0: {0:.2f}'.format(self.x0[2]) + outStr += '\n nCx: {0:d}'.format(self.nCx) + outStr += '\n nCy: {0:d}'.format(self.nCy) + outStr += '\n nCz: {0:d}'.format(self.nCz) + outStr += printH(self.hx, outStr='\n hx:') + outStr += printH(self.hy, outStr='\n hy:') + outStr += printH(self.hz, outStr='\n hz:') return outStr diff --git a/SimPEG/Mesh/TreeMesh.py b/SimPEG/Mesh/TreeMesh.py index 4f4e98d3..55b5101e 100644 --- a/SimPEG/Mesh/TreeMesh.py +++ b/SimPEG/Mesh/TreeMesh.py @@ -99,6 +99,7 @@ import matplotlib.cm as cmx import TreeUtils from InnerProducts import InnerProducts from BaseMesh import BaseMesh +from TensorMesh import TensorMesh import time MAX_BITS = 20 @@ -107,7 +108,7 @@ class TreeMesh(BaseMesh, InnerProducts): _meshType = 'TREE' - def __init__(self, h_in, x0_in=None, levels=3): + def __init__(self, h_in, x0_in=None, levels=None): assert type(h_in) is list, 'h_in must be a list' assert len(h_in) in [2,3], "There is only support for TreeMesh in 2D or 3D." @@ -120,6 +121,7 @@ class TreeMesh(BaseMesh, InnerProducts): h_i = Utils.meshTensor(h_i) assert isinstance(h_i, np.ndarray), ("h[%i] is not a numpy array." % i) assert len(h_i.shape) == 1, ("h[%i] must be a 1D numpy array." % i) + if levels is None:levels = int(np.log2(len(h_i))) assert len(h_i) == 2**levels, "must make h and levels match" h[i] = h_i[:] # make a copy. self._h = h @@ -184,6 +186,58 @@ class TreeMesh(BaseMesh, InnerProducts): @property def levels(self): return self._levels + @property + def fill(self): + return float(self.nC)/((2**self.maxLevel)**self.dim) + + @property + def maxLevel(self): + l = 0 + for cell in self._cells: + p = self._pointer(cell) + l = max(l,p[-1]) + return l + + def __str__(self): + outStr = ' ---- %sTreeMesh ---- '%('Oc' if self.dim == 3 else 'Quad') + def printH(hx, outStr=''): + i = -1 + while True: + i = i + 1 + if i > hx.size: + break + elif i == hx.size: + break + h = hx[i] + n = 1 + for j in range(i+1, hx.size): + if hx[j] == h: + n = n + 1 + i = i + 1 + else: + break + if n == 1: + outStr += ' {0:.2f},'.format(h) + else: + outStr += ' {0:d}*{1:.2f},'.format(n,h) + return outStr[:-1] + + if self.dim == 2: + outStr += '\n x0: {0:.2f}'.format(self.x0[0]) + outStr += '\n y0: {0:.2f}'.format(self.x0[1]) + outStr += printH(self.hx, outStr='\n hx:') + outStr += printH(self.hy, outStr='\n hy:') + elif self.dim == 3: + outStr += '\n x0: {0:.2f}'.format(self.x0[0]) + outStr += '\n y0: {0:.2f}'.format(self.x0[1]) + outStr += '\n z0: {0:.2f}'.format(self.x0[2]) + outStr += printH(self.hx, outStr='\n hx:') + outStr += printH(self.hy, outStr='\n hy:') + outStr += printH(self.hz, outStr='\n hz:') + outStr += '\n nC: {0:d}'.format(self.nC) + outStr += '\n Fill: %2.2f%%'%(self.fill*100) + return outStr + @property def h(self): """h is a list containing the cell widths of the tensor mesh in each dimension.""" @@ -2109,8 +2163,8 @@ class TreeMesh(BaseMesh, InnerProducts): if showIt:plt.show() - def plotImage(self, I, ax=None, showIt=True): - if self.dim == 3: raise Exception() + def plotImage(self, I, ax=None, showIt=True, grid=False): + if self.dim == 3: raise Exception('Use plot slice?') if ax is None: ax = plt.subplot(111) jet = cm = plt.get_cmap('jet') @@ -2120,12 +2174,102 @@ class TreeMesh(BaseMesh, InnerProducts): ax.set_ylim((self.x0[1], self.h[1].sum())) for ii, node in enumerate(self._sortedCells): x0, sz = self._cellN(node), self._cellH(node) - ax.add_patch(plt.Rectangle((x0[0], x0[1]), sz[0], sz[1], facecolor=scalarMap.to_rgba(I[ii]), edgecolor='k')) + ax.add_patch(plt.Rectangle((x0[0], x0[1]), sz[0], sz[1], facecolor=scalarMap.to_rgba(I[ii]), edgecolor='k' if grid else 'none')) # if text: ax.text(self.center[0],self.center[1],self.num) scalarMap._A = [] # http://stackoverflow.com/questions/8342549/matplotlib-add-colorbar-to-a-sequence-of-line-plots plt.colorbar(scalarMap) if showIt: plt.show() + def plotSlice(self, v, vType='CC', + normal='Z', ind=None, grid=True, view='real', + ax=None, clim=None, showIt=False, + pcolorOpts={}, + streamOpts={'color':'k'}, + gridOpts={'color':'k'}): + + assert vType in ['CC','F','E'] + assert self.dim == 3 + + szSliceDim = len(getattr(self, 'h'+normal.lower())) #: Size of the sliced dimension + if ind is None: ind = int(szSliceDim/2) + assert type(ind) in [int, long], 'ind must be an integer' + indLoc = getattr(self,'vectorCC'+normal.lower())[ind] + normalInd = {'X':0,'Y':1,'Z':2}[normal] + antiNormalInd = {'X':[1,2],'Y':[0,2],'Z':[0,1]}[normal] + h2d = [] + x2d = [] + if 'X' not in normal: + h2d.append(self.hx) + x2d.append(self.x0[0]) + if 'Y' not in normal: + h2d.append(self.hy) + x2d.append(self.x0[1]) + if 'Z' not in normal: + h2d.append(self.hz) + x2d.append(self.x0[2]) + tM = TensorMesh(h2d, x2d) #: Temp Mesh + + def getLocs(*args): + if len(args) == 1: + grids = (args[0],args[0],args[0]) + else: + assert len(args) == 3 + grids = args + one = np.ones((grids[0].shape[0],1))*indLoc + if normal == 'X': + return np.hstack((one, grids[0][:,[0]], grids[1][:,[1]])) + if normal == 'Y': + return np.hstack((grids[0][:,[0]], one, grids[1][:,[1]])) + if normal == 'Z': + return np.hstack((grids[0][:,[0]], grids[1][:,[1]], one)) + def doSlice(v): + if vType == 'CC': + P = self.getInterpolationMat(getLocs(tM.gridCC),'CC') + elif vType in ['F', 'E']: + Ps = [] + gridX = getLocs(getattr(tM, 'grid' + vType + 'x')) + gridY = getLocs(getattr(tM, 'grid' + vType + 'y')) + Ps += [self.getInterpolationMat(gridX,vType + ('y' if normal == 'X' else 'x'))] + Ps += [self.getInterpolationMat(gridY,vType + ('y' if normal == 'Z' else 'z'))] + P = sp.vstack(Ps) + return P*v + + v2d = doSlice(v) + + if ax is None: + fig = plt.figure() + ax = plt.subplot(111) + else: + assert isinstance(ax, matplotlib.axes.Axes), "ax must be an matplotlib.axes.Axes" + fig = ax.figure + + out = tM._plotImage2D(v2d, vType=vType, view=view, + ax=ax, clim=clim, + pcolorOpts=pcolorOpts, streamOpts=streamOpts) + + ax.set_xlabel('y' if normal == 'X' else 'x') + ax.set_ylabel('y' if normal == 'Z' else 'z') + ax.set_title('Slice %d, %s = %4.2f' % (ind,normal,indLoc)) + + if grid: + _ = antiNormalInd + X = [] + Y = [] + for cell in self._cells: + p = self._pointer(cell) + n, h = self._cellN(p), self._cellH(p) + if n[normalInd]indLoc: + X += [n[_[0]] , n[_[0]] + h[_[0]], n[_[0]] + h[_[0]], n[_[0]] , n[_[0]], np.nan] + Y += [n[_[1]] , n[_[1]] , n[_[1]] + h[_[1]], n[_[1]] + h[_[1]], n[_[1]], np.nan] + out = list(out) + out += ax.plot(X,Y, **gridOpts) + if len(out) > 2: # this is not robust, searching for the streamlines would be better + out[1].lines.set_zorder(200) + out[1].arrows.set_zorder(201) + if showIt: plt.show() + return tuple(out) + + class Cell(object): def __init__(self, mesh, index, pointer): self.mesh = mesh @@ -2192,27 +2336,24 @@ if __name__ == '__main__': def function(cell): r = cell.center - np.array([0.5]*len(cell.center)) - dist1 = np.sqrt(r.dot(r)) - 0.08 - dist2 = np.abs(cell.center[-1] - topo(cell.center[0])) + dist = np.sqrt(r.dot(r)) + # dist2 = np.abs(cell.center[-1] - topo(cell.center[0])) - dist = min([dist1,dist2]) + # dist = min([dist1,dist2]) # if dist < 0.05: # return 5 - if dist < 0.05: - return 6 - if dist < 0.2: + if dist < 0.1: return 5 - if dist < 0.3: + if dist < 0.2: return 4 - if dist < 1.0: + if dist < 0.4: return 3 - else: - return 0 + return 2 # T = TreeMesh([[(1,128)],[(1,128)],[(1,128)]],levels=7) - # T = TreeMesh([128,128,128],levels=7) + T = TreeMesh([128,128,128]) # T = TreeMesh([64,64],levels=6) - T = TreeMesh([4,4,4],levels=2) + # T = TreeMesh([4,4,4],levels=2) # T = TreeMesh([[(1,128)],[(1,128)]],levels=7) # T.refine(lambda xc:2, balance=False) # T._index([0,0,0]) @@ -2220,14 +2361,12 @@ if __name__ == '__main__': # tic = time.time() - # T.refine(function)#, balance=False) + T.refine(function)#, balance=False) # print time.time() - tic # print T.nC + T.plotSlice(np.log(T.vol))#np.random.rand(T.nC)) - print T.nC - - P = T.getInterpolationMat([0.2,0,0], 'Ex') - print P.todense() + plt.show() blah # T.plotImage(np.arange(len(T.vol)),showIt=True) diff --git a/SimPEG/Mesh/View.py b/SimPEG/Mesh/View.py index 38432d3c..272a2f47 100644 --- a/SimPEG/Mesh/View.py +++ b/SimPEG/Mesh/View.py @@ -216,6 +216,7 @@ class TensorView(object): if ind is None: ind = int(szSliceDim/2) assert type(ind) in [int, long], 'ind must be an integer' + assert not (v.dtype == complex and view == 'vec'), 'Can not plot a complex vector.' # The slicing and plotting code!! def getIndSlice(v):