PlotSlice in OcTree

This commit is contained in:
Rowan Cockett
2015-11-18 11:52:05 -08:00
parent 1718e3506d
commit c1b5f45ac7
3 changed files with 181 additions and 41 deletions
+20 -20
View File
@@ -413,34 +413,34 @@ class TensorMesh(BaseTensorMesh, TensorView, DiffOperators, InnerProducts):
break
if n == 1:
outStr = outStr + ' {0:.2f},'.format(h)
outStr += ' {0:.2f},'.format(h)
else:
outStr = outStr + ' {0:d}*{1:.2f},'.format(n,h)
outStr += ' {0:d}*{1:.2f},'.format(n,h)
return outStr[:-1]
if self.dim == 1:
outStr = outStr + '\n x0: {0:.2f}'.format(self.x0[0])
outStr = outStr + '\n nCx: {0:d}'.format(self.nCx)
outStr = outStr + printH(self.hx, outStr='\n hx:')
outStr += '\n x0: {0:.2f}'.format(self.x0[0])
outStr += '\n nCx: {0:d}'.format(self.nCx)
outStr += printH(self.hx, outStr='\n hx:')
pass
elif self.dim == 2:
outStr = outStr + '\n x0: {0:.2f}'.format(self.x0[0])
outStr = outStr + '\n y0: {0:.2f}'.format(self.x0[1])
outStr = outStr + '\n nCx: {0:d}'.format(self.nCx)
outStr = outStr + '\n nCy: {0:d}'.format(self.nCy)
outStr = outStr + printH(self.hx, outStr='\n hx:')
outStr = outStr + printH(self.hy, outStr='\n hy:')
outStr += '\n x0: {0:.2f}'.format(self.x0[0])
outStr += '\n y0: {0:.2f}'.format(self.x0[1])
outStr += '\n nCx: {0:d}'.format(self.nCx)
outStr += '\n nCy: {0:d}'.format(self.nCy)
outStr += printH(self.hx, outStr='\n hx:')
outStr += printH(self.hy, outStr='\n hy:')
elif self.dim == 3:
outStr = outStr + '\n x0: {0:.2f}'.format(self.x0[0])
outStr = outStr + '\n y0: {0:.2f}'.format(self.x0[1])
outStr = outStr + '\n z0: {0:.2f}'.format(self.x0[2])
outStr = outStr + '\n nCx: {0:d}'.format(self.nCx)
outStr = outStr + '\n nCy: {0:d}'.format(self.nCy)
outStr = outStr + '\n nCz: {0:d}'.format(self.nCz)
outStr = outStr + printH(self.hx, outStr='\n hx:')
outStr = outStr + printH(self.hy, outStr='\n hy:')
outStr = outStr + printH(self.hz, outStr='\n hz:')
outStr += '\n x0: {0:.2f}'.format(self.x0[0])
outStr += '\n y0: {0:.2f}'.format(self.x0[1])
outStr += '\n z0: {0:.2f}'.format(self.x0[2])
outStr += '\n nCx: {0:d}'.format(self.nCx)
outStr += '\n nCy: {0:d}'.format(self.nCy)
outStr += '\n nCz: {0:d}'.format(self.nCz)
outStr += printH(self.hx, outStr='\n hx:')
outStr += printH(self.hy, outStr='\n hy:')
outStr += printH(self.hz, outStr='\n hz:')
return outStr
+160 -21
View File
@@ -99,6 +99,7 @@ import matplotlib.cm as cmx
import TreeUtils
from InnerProducts import InnerProducts
from BaseMesh import BaseMesh
from TensorMesh import TensorMesh
import time
MAX_BITS = 20
@@ -107,7 +108,7 @@ class TreeMesh(BaseMesh, InnerProducts):
_meshType = 'TREE'
def __init__(self, h_in, x0_in=None, levels=3):
def __init__(self, h_in, x0_in=None, levels=None):
assert type(h_in) is list, 'h_in must be a list'
assert len(h_in) in [2,3], "There is only support for TreeMesh in 2D or 3D."
@@ -120,6 +121,7 @@ class TreeMesh(BaseMesh, InnerProducts):
h_i = Utils.meshTensor(h_i)
assert isinstance(h_i, np.ndarray), ("h[%i] is not a numpy array." % i)
assert len(h_i.shape) == 1, ("h[%i] must be a 1D numpy array." % i)
if levels is None:levels = int(np.log2(len(h_i)))
assert len(h_i) == 2**levels, "must make h and levels match"
h[i] = h_i[:] # make a copy.
self._h = h
@@ -184,6 +186,58 @@ class TreeMesh(BaseMesh, InnerProducts):
@property
def levels(self): return self._levels
@property
def fill(self):
return float(self.nC)/((2**self.maxLevel)**self.dim)
@property
def maxLevel(self):
l = 0
for cell in self._cells:
p = self._pointer(cell)
l = max(l,p[-1])
return l
def __str__(self):
outStr = ' ---- %sTreeMesh ---- '%('Oc' if self.dim == 3 else 'Quad')
def printH(hx, outStr=''):
i = -1
while True:
i = i + 1
if i > hx.size:
break
elif i == hx.size:
break
h = hx[i]
n = 1
for j in range(i+1, hx.size):
if hx[j] == h:
n = n + 1
i = i + 1
else:
break
if n == 1:
outStr += ' {0:.2f},'.format(h)
else:
outStr += ' {0:d}*{1:.2f},'.format(n,h)
return outStr[:-1]
if self.dim == 2:
outStr += '\n x0: {0:.2f}'.format(self.x0[0])
outStr += '\n y0: {0:.2f}'.format(self.x0[1])
outStr += printH(self.hx, outStr='\n hx:')
outStr += printH(self.hy, outStr='\n hy:')
elif self.dim == 3:
outStr += '\n x0: {0:.2f}'.format(self.x0[0])
outStr += '\n y0: {0:.2f}'.format(self.x0[1])
outStr += '\n z0: {0:.2f}'.format(self.x0[2])
outStr += printH(self.hx, outStr='\n hx:')
outStr += printH(self.hy, outStr='\n hy:')
outStr += printH(self.hz, outStr='\n hz:')
outStr += '\n nC: {0:d}'.format(self.nC)
outStr += '\n Fill: %2.2f%%'%(self.fill*100)
return outStr
@property
def h(self):
"""h is a list containing the cell widths of the tensor mesh in each dimension."""
@@ -2109,8 +2163,8 @@ class TreeMesh(BaseMesh, InnerProducts):
if showIt:plt.show()
def plotImage(self, I, ax=None, showIt=True):
if self.dim == 3: raise Exception()
def plotImage(self, I, ax=None, showIt=True, grid=False):
if self.dim == 3: raise Exception('Use plot slice?')
if ax is None: ax = plt.subplot(111)
jet = cm = plt.get_cmap('jet')
@@ -2120,12 +2174,102 @@ class TreeMesh(BaseMesh, InnerProducts):
ax.set_ylim((self.x0[1], self.h[1].sum()))
for ii, node in enumerate(self._sortedCells):
x0, sz = self._cellN(node), self._cellH(node)
ax.add_patch(plt.Rectangle((x0[0], x0[1]), sz[0], sz[1], facecolor=scalarMap.to_rgba(I[ii]), edgecolor='k'))
ax.add_patch(plt.Rectangle((x0[0], x0[1]), sz[0], sz[1], facecolor=scalarMap.to_rgba(I[ii]), edgecolor='k' if grid else 'none'))
# if text: ax.text(self.center[0],self.center[1],self.num)
scalarMap._A = [] # http://stackoverflow.com/questions/8342549/matplotlib-add-colorbar-to-a-sequence-of-line-plots
plt.colorbar(scalarMap)
if showIt: plt.show()
def plotSlice(self, v, vType='CC',
normal='Z', ind=None, grid=True, view='real',
ax=None, clim=None, showIt=False,
pcolorOpts={},
streamOpts={'color':'k'},
gridOpts={'color':'k'}):
assert vType in ['CC','F','E']
assert self.dim == 3
szSliceDim = len(getattr(self, 'h'+normal.lower())) #: Size of the sliced dimension
if ind is None: ind = int(szSliceDim/2)
assert type(ind) in [int, long], 'ind must be an integer'
indLoc = getattr(self,'vectorCC'+normal.lower())[ind]
normalInd = {'X':0,'Y':1,'Z':2}[normal]
antiNormalInd = {'X':[1,2],'Y':[0,2],'Z':[0,1]}[normal]
h2d = []
x2d = []
if 'X' not in normal:
h2d.append(self.hx)
x2d.append(self.x0[0])
if 'Y' not in normal:
h2d.append(self.hy)
x2d.append(self.x0[1])
if 'Z' not in normal:
h2d.append(self.hz)
x2d.append(self.x0[2])
tM = TensorMesh(h2d, x2d) #: Temp Mesh
def getLocs(*args):
if len(args) == 1:
grids = (args[0],args[0],args[0])
else:
assert len(args) == 3
grids = args
one = np.ones((grids[0].shape[0],1))*indLoc
if normal == 'X':
return np.hstack((one, grids[0][:,[0]], grids[1][:,[1]]))
if normal == 'Y':
return np.hstack((grids[0][:,[0]], one, grids[1][:,[1]]))
if normal == 'Z':
return np.hstack((grids[0][:,[0]], grids[1][:,[1]], one))
def doSlice(v):
if vType == 'CC':
P = self.getInterpolationMat(getLocs(tM.gridCC),'CC')
elif vType in ['F', 'E']:
Ps = []
gridX = getLocs(getattr(tM, 'grid' + vType + 'x'))
gridY = getLocs(getattr(tM, 'grid' + vType + 'y'))
Ps += [self.getInterpolationMat(gridX,vType + ('y' if normal == 'X' else 'x'))]
Ps += [self.getInterpolationMat(gridY,vType + ('y' if normal == 'Z' else 'z'))]
P = sp.vstack(Ps)
return P*v
v2d = doSlice(v)
if ax is None:
fig = plt.figure()
ax = plt.subplot(111)
else:
assert isinstance(ax, matplotlib.axes.Axes), "ax must be an matplotlib.axes.Axes"
fig = ax.figure
out = tM._plotImage2D(v2d, vType=vType, view=view,
ax=ax, clim=clim,
pcolorOpts=pcolorOpts, streamOpts=streamOpts)
ax.set_xlabel('y' if normal == 'X' else 'x')
ax.set_ylabel('y' if normal == 'Z' else 'z')
ax.set_title('Slice %d, %s = %4.2f' % (ind,normal,indLoc))
if grid:
_ = antiNormalInd
X = []
Y = []
for cell in self._cells:
p = self._pointer(cell)
n, h = self._cellN(p), self._cellH(p)
if n[normalInd]<indLoc and n[normalInd]+h[normalInd]>indLoc:
X += [n[_[0]] , n[_[0]] + h[_[0]], n[_[0]] + h[_[0]], n[_[0]] , n[_[0]], np.nan]
Y += [n[_[1]] , n[_[1]] , n[_[1]] + h[_[1]], n[_[1]] + h[_[1]], n[_[1]], np.nan]
out = list(out)
out += ax.plot(X,Y, **gridOpts)
if len(out) > 2: # this is not robust, searching for the streamlines would be better
out[1].lines.set_zorder(200)
out[1].arrows.set_zorder(201)
if showIt: plt.show()
return tuple(out)
class Cell(object):
def __init__(self, mesh, index, pointer):
self.mesh = mesh
@@ -2192,27 +2336,24 @@ if __name__ == '__main__':
def function(cell):
r = cell.center - np.array([0.5]*len(cell.center))
dist1 = np.sqrt(r.dot(r)) - 0.08
dist2 = np.abs(cell.center[-1] - topo(cell.center[0]))
dist = np.sqrt(r.dot(r))
# dist2 = np.abs(cell.center[-1] - topo(cell.center[0]))
dist = min([dist1,dist2])
# dist = min([dist1,dist2])
# if dist < 0.05:
# return 5
if dist < 0.05:
return 6
if dist < 0.2:
if dist < 0.1:
return 5
if dist < 0.3:
if dist < 0.2:
return 4
if dist < 1.0:
if dist < 0.4:
return 3
else:
return 0
return 2
# T = TreeMesh([[(1,128)],[(1,128)],[(1,128)]],levels=7)
# T = TreeMesh([128,128,128],levels=7)
T = TreeMesh([128,128,128])
# T = TreeMesh([64,64],levels=6)
T = TreeMesh([4,4,4],levels=2)
# T = TreeMesh([4,4,4],levels=2)
# T = TreeMesh([[(1,128)],[(1,128)]],levels=7)
# T.refine(lambda xc:2, balance=False)
# T._index([0,0,0])
@@ -2220,14 +2361,12 @@ if __name__ == '__main__':
# tic = time.time()
# T.refine(function)#, balance=False)
T.refine(function)#, balance=False)
# print time.time() - tic
# print T.nC
T.plotSlice(np.log(T.vol))#np.random.rand(T.nC))
print T.nC
P = T.getInterpolationMat([0.2,0,0], 'Ex')
print P.todense()
plt.show()
blah
# T.plotImage(np.arange(len(T.vol)),showIt=True)
+1
View File
@@ -216,6 +216,7 @@ class TensorView(object):
if ind is None: ind = int(szSliceDim/2)
assert type(ind) in [int, long], 'ind must be an integer'
assert not (v.dtype == complex and view == 'vec'), 'Can not plot a complex vector.'
# The slicing and plotting code!!
def getIndSlice(v):