mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-27 19:48:52 +08:00
PlotSlice in OcTree
This commit is contained in:
+20
-20
@@ -413,34 +413,34 @@ class TensorMesh(BaseTensorMesh, TensorView, DiffOperators, InnerProducts):
|
||||
break
|
||||
|
||||
if n == 1:
|
||||
outStr = outStr + ' {0:.2f},'.format(h)
|
||||
outStr += ' {0:.2f},'.format(h)
|
||||
else:
|
||||
outStr = outStr + ' {0:d}*{1:.2f},'.format(n,h)
|
||||
outStr += ' {0:d}*{1:.2f},'.format(n,h)
|
||||
|
||||
return outStr[:-1]
|
||||
|
||||
if self.dim == 1:
|
||||
outStr = outStr + '\n x0: {0:.2f}'.format(self.x0[0])
|
||||
outStr = outStr + '\n nCx: {0:d}'.format(self.nCx)
|
||||
outStr = outStr + printH(self.hx, outStr='\n hx:')
|
||||
outStr += '\n x0: {0:.2f}'.format(self.x0[0])
|
||||
outStr += '\n nCx: {0:d}'.format(self.nCx)
|
||||
outStr += printH(self.hx, outStr='\n hx:')
|
||||
pass
|
||||
elif self.dim == 2:
|
||||
outStr = outStr + '\n x0: {0:.2f}'.format(self.x0[0])
|
||||
outStr = outStr + '\n y0: {0:.2f}'.format(self.x0[1])
|
||||
outStr = outStr + '\n nCx: {0:d}'.format(self.nCx)
|
||||
outStr = outStr + '\n nCy: {0:d}'.format(self.nCy)
|
||||
outStr = outStr + printH(self.hx, outStr='\n hx:')
|
||||
outStr = outStr + printH(self.hy, outStr='\n hy:')
|
||||
outStr += '\n x0: {0:.2f}'.format(self.x0[0])
|
||||
outStr += '\n y0: {0:.2f}'.format(self.x0[1])
|
||||
outStr += '\n nCx: {0:d}'.format(self.nCx)
|
||||
outStr += '\n nCy: {0:d}'.format(self.nCy)
|
||||
outStr += printH(self.hx, outStr='\n hx:')
|
||||
outStr += printH(self.hy, outStr='\n hy:')
|
||||
elif self.dim == 3:
|
||||
outStr = outStr + '\n x0: {0:.2f}'.format(self.x0[0])
|
||||
outStr = outStr + '\n y0: {0:.2f}'.format(self.x0[1])
|
||||
outStr = outStr + '\n z0: {0:.2f}'.format(self.x0[2])
|
||||
outStr = outStr + '\n nCx: {0:d}'.format(self.nCx)
|
||||
outStr = outStr + '\n nCy: {0:d}'.format(self.nCy)
|
||||
outStr = outStr + '\n nCz: {0:d}'.format(self.nCz)
|
||||
outStr = outStr + printH(self.hx, outStr='\n hx:')
|
||||
outStr = outStr + printH(self.hy, outStr='\n hy:')
|
||||
outStr = outStr + printH(self.hz, outStr='\n hz:')
|
||||
outStr += '\n x0: {0:.2f}'.format(self.x0[0])
|
||||
outStr += '\n y0: {0:.2f}'.format(self.x0[1])
|
||||
outStr += '\n z0: {0:.2f}'.format(self.x0[2])
|
||||
outStr += '\n nCx: {0:d}'.format(self.nCx)
|
||||
outStr += '\n nCy: {0:d}'.format(self.nCy)
|
||||
outStr += '\n nCz: {0:d}'.format(self.nCz)
|
||||
outStr += printH(self.hx, outStr='\n hx:')
|
||||
outStr += printH(self.hy, outStr='\n hy:')
|
||||
outStr += printH(self.hz, outStr='\n hz:')
|
||||
|
||||
return outStr
|
||||
|
||||
|
||||
+160
-21
@@ -99,6 +99,7 @@ import matplotlib.cm as cmx
|
||||
import TreeUtils
|
||||
from InnerProducts import InnerProducts
|
||||
from BaseMesh import BaseMesh
|
||||
from TensorMesh import TensorMesh
|
||||
import time
|
||||
|
||||
MAX_BITS = 20
|
||||
@@ -107,7 +108,7 @@ class TreeMesh(BaseMesh, InnerProducts):
|
||||
|
||||
_meshType = 'TREE'
|
||||
|
||||
def __init__(self, h_in, x0_in=None, levels=3):
|
||||
def __init__(self, h_in, x0_in=None, levels=None):
|
||||
assert type(h_in) is list, 'h_in must be a list'
|
||||
assert len(h_in) in [2,3], "There is only support for TreeMesh in 2D or 3D."
|
||||
|
||||
@@ -120,6 +121,7 @@ class TreeMesh(BaseMesh, InnerProducts):
|
||||
h_i = Utils.meshTensor(h_i)
|
||||
assert isinstance(h_i, np.ndarray), ("h[%i] is not a numpy array." % i)
|
||||
assert len(h_i.shape) == 1, ("h[%i] must be a 1D numpy array." % i)
|
||||
if levels is None:levels = int(np.log2(len(h_i)))
|
||||
assert len(h_i) == 2**levels, "must make h and levels match"
|
||||
h[i] = h_i[:] # make a copy.
|
||||
self._h = h
|
||||
@@ -184,6 +186,58 @@ class TreeMesh(BaseMesh, InnerProducts):
|
||||
@property
|
||||
def levels(self): return self._levels
|
||||
|
||||
@property
|
||||
def fill(self):
|
||||
return float(self.nC)/((2**self.maxLevel)**self.dim)
|
||||
|
||||
@property
|
||||
def maxLevel(self):
|
||||
l = 0
|
||||
for cell in self._cells:
|
||||
p = self._pointer(cell)
|
||||
l = max(l,p[-1])
|
||||
return l
|
||||
|
||||
def __str__(self):
|
||||
outStr = ' ---- %sTreeMesh ---- '%('Oc' if self.dim == 3 else 'Quad')
|
||||
def printH(hx, outStr=''):
|
||||
i = -1
|
||||
while True:
|
||||
i = i + 1
|
||||
if i > hx.size:
|
||||
break
|
||||
elif i == hx.size:
|
||||
break
|
||||
h = hx[i]
|
||||
n = 1
|
||||
for j in range(i+1, hx.size):
|
||||
if hx[j] == h:
|
||||
n = n + 1
|
||||
i = i + 1
|
||||
else:
|
||||
break
|
||||
if n == 1:
|
||||
outStr += ' {0:.2f},'.format(h)
|
||||
else:
|
||||
outStr += ' {0:d}*{1:.2f},'.format(n,h)
|
||||
return outStr[:-1]
|
||||
|
||||
if self.dim == 2:
|
||||
outStr += '\n x0: {0:.2f}'.format(self.x0[0])
|
||||
outStr += '\n y0: {0:.2f}'.format(self.x0[1])
|
||||
outStr += printH(self.hx, outStr='\n hx:')
|
||||
outStr += printH(self.hy, outStr='\n hy:')
|
||||
elif self.dim == 3:
|
||||
outStr += '\n x0: {0:.2f}'.format(self.x0[0])
|
||||
outStr += '\n y0: {0:.2f}'.format(self.x0[1])
|
||||
outStr += '\n z0: {0:.2f}'.format(self.x0[2])
|
||||
outStr += printH(self.hx, outStr='\n hx:')
|
||||
outStr += printH(self.hy, outStr='\n hy:')
|
||||
outStr += printH(self.hz, outStr='\n hz:')
|
||||
outStr += '\n nC: {0:d}'.format(self.nC)
|
||||
outStr += '\n Fill: %2.2f%%'%(self.fill*100)
|
||||
return outStr
|
||||
|
||||
@property
|
||||
def h(self):
|
||||
"""h is a list containing the cell widths of the tensor mesh in each dimension."""
|
||||
@@ -2109,8 +2163,8 @@ class TreeMesh(BaseMesh, InnerProducts):
|
||||
|
||||
if showIt:plt.show()
|
||||
|
||||
def plotImage(self, I, ax=None, showIt=True):
|
||||
if self.dim == 3: raise Exception()
|
||||
def plotImage(self, I, ax=None, showIt=True, grid=False):
|
||||
if self.dim == 3: raise Exception('Use plot slice?')
|
||||
|
||||
if ax is None: ax = plt.subplot(111)
|
||||
jet = cm = plt.get_cmap('jet')
|
||||
@@ -2120,12 +2174,102 @@ class TreeMesh(BaseMesh, InnerProducts):
|
||||
ax.set_ylim((self.x0[1], self.h[1].sum()))
|
||||
for ii, node in enumerate(self._sortedCells):
|
||||
x0, sz = self._cellN(node), self._cellH(node)
|
||||
ax.add_patch(plt.Rectangle((x0[0], x0[1]), sz[0], sz[1], facecolor=scalarMap.to_rgba(I[ii]), edgecolor='k'))
|
||||
ax.add_patch(plt.Rectangle((x0[0], x0[1]), sz[0], sz[1], facecolor=scalarMap.to_rgba(I[ii]), edgecolor='k' if grid else 'none'))
|
||||
# if text: ax.text(self.center[0],self.center[1],self.num)
|
||||
scalarMap._A = [] # http://stackoverflow.com/questions/8342549/matplotlib-add-colorbar-to-a-sequence-of-line-plots
|
||||
plt.colorbar(scalarMap)
|
||||
if showIt: plt.show()
|
||||
|
||||
def plotSlice(self, v, vType='CC',
|
||||
normal='Z', ind=None, grid=True, view='real',
|
||||
ax=None, clim=None, showIt=False,
|
||||
pcolorOpts={},
|
||||
streamOpts={'color':'k'},
|
||||
gridOpts={'color':'k'}):
|
||||
|
||||
assert vType in ['CC','F','E']
|
||||
assert self.dim == 3
|
||||
|
||||
szSliceDim = len(getattr(self, 'h'+normal.lower())) #: Size of the sliced dimension
|
||||
if ind is None: ind = int(szSliceDim/2)
|
||||
assert type(ind) in [int, long], 'ind must be an integer'
|
||||
indLoc = getattr(self,'vectorCC'+normal.lower())[ind]
|
||||
normalInd = {'X':0,'Y':1,'Z':2}[normal]
|
||||
antiNormalInd = {'X':[1,2],'Y':[0,2],'Z':[0,1]}[normal]
|
||||
h2d = []
|
||||
x2d = []
|
||||
if 'X' not in normal:
|
||||
h2d.append(self.hx)
|
||||
x2d.append(self.x0[0])
|
||||
if 'Y' not in normal:
|
||||
h2d.append(self.hy)
|
||||
x2d.append(self.x0[1])
|
||||
if 'Z' not in normal:
|
||||
h2d.append(self.hz)
|
||||
x2d.append(self.x0[2])
|
||||
tM = TensorMesh(h2d, x2d) #: Temp Mesh
|
||||
|
||||
def getLocs(*args):
|
||||
if len(args) == 1:
|
||||
grids = (args[0],args[0],args[0])
|
||||
else:
|
||||
assert len(args) == 3
|
||||
grids = args
|
||||
one = np.ones((grids[0].shape[0],1))*indLoc
|
||||
if normal == 'X':
|
||||
return np.hstack((one, grids[0][:,[0]], grids[1][:,[1]]))
|
||||
if normal == 'Y':
|
||||
return np.hstack((grids[0][:,[0]], one, grids[1][:,[1]]))
|
||||
if normal == 'Z':
|
||||
return np.hstack((grids[0][:,[0]], grids[1][:,[1]], one))
|
||||
def doSlice(v):
|
||||
if vType == 'CC':
|
||||
P = self.getInterpolationMat(getLocs(tM.gridCC),'CC')
|
||||
elif vType in ['F', 'E']:
|
||||
Ps = []
|
||||
gridX = getLocs(getattr(tM, 'grid' + vType + 'x'))
|
||||
gridY = getLocs(getattr(tM, 'grid' + vType + 'y'))
|
||||
Ps += [self.getInterpolationMat(gridX,vType + ('y' if normal == 'X' else 'x'))]
|
||||
Ps += [self.getInterpolationMat(gridY,vType + ('y' if normal == 'Z' else 'z'))]
|
||||
P = sp.vstack(Ps)
|
||||
return P*v
|
||||
|
||||
v2d = doSlice(v)
|
||||
|
||||
if ax is None:
|
||||
fig = plt.figure()
|
||||
ax = plt.subplot(111)
|
||||
else:
|
||||
assert isinstance(ax, matplotlib.axes.Axes), "ax must be an matplotlib.axes.Axes"
|
||||
fig = ax.figure
|
||||
|
||||
out = tM._plotImage2D(v2d, vType=vType, view=view,
|
||||
ax=ax, clim=clim,
|
||||
pcolorOpts=pcolorOpts, streamOpts=streamOpts)
|
||||
|
||||
ax.set_xlabel('y' if normal == 'X' else 'x')
|
||||
ax.set_ylabel('y' if normal == 'Z' else 'z')
|
||||
ax.set_title('Slice %d, %s = %4.2f' % (ind,normal,indLoc))
|
||||
|
||||
if grid:
|
||||
_ = antiNormalInd
|
||||
X = []
|
||||
Y = []
|
||||
for cell in self._cells:
|
||||
p = self._pointer(cell)
|
||||
n, h = self._cellN(p), self._cellH(p)
|
||||
if n[normalInd]<indLoc and n[normalInd]+h[normalInd]>indLoc:
|
||||
X += [n[_[0]] , n[_[0]] + h[_[0]], n[_[0]] + h[_[0]], n[_[0]] , n[_[0]], np.nan]
|
||||
Y += [n[_[1]] , n[_[1]] , n[_[1]] + h[_[1]], n[_[1]] + h[_[1]], n[_[1]], np.nan]
|
||||
out = list(out)
|
||||
out += ax.plot(X,Y, **gridOpts)
|
||||
if len(out) > 2: # this is not robust, searching for the streamlines would be better
|
||||
out[1].lines.set_zorder(200)
|
||||
out[1].arrows.set_zorder(201)
|
||||
if showIt: plt.show()
|
||||
return tuple(out)
|
||||
|
||||
|
||||
class Cell(object):
|
||||
def __init__(self, mesh, index, pointer):
|
||||
self.mesh = mesh
|
||||
@@ -2192,27 +2336,24 @@ if __name__ == '__main__':
|
||||
|
||||
def function(cell):
|
||||
r = cell.center - np.array([0.5]*len(cell.center))
|
||||
dist1 = np.sqrt(r.dot(r)) - 0.08
|
||||
dist2 = np.abs(cell.center[-1] - topo(cell.center[0]))
|
||||
dist = np.sqrt(r.dot(r))
|
||||
# dist2 = np.abs(cell.center[-1] - topo(cell.center[0]))
|
||||
|
||||
dist = min([dist1,dist2])
|
||||
# dist = min([dist1,dist2])
|
||||
# if dist < 0.05:
|
||||
# return 5
|
||||
if dist < 0.05:
|
||||
return 6
|
||||
if dist < 0.2:
|
||||
if dist < 0.1:
|
||||
return 5
|
||||
if dist < 0.3:
|
||||
if dist < 0.2:
|
||||
return 4
|
||||
if dist < 1.0:
|
||||
if dist < 0.4:
|
||||
return 3
|
||||
else:
|
||||
return 0
|
||||
return 2
|
||||
|
||||
# T = TreeMesh([[(1,128)],[(1,128)],[(1,128)]],levels=7)
|
||||
# T = TreeMesh([128,128,128],levels=7)
|
||||
T = TreeMesh([128,128,128])
|
||||
# T = TreeMesh([64,64],levels=6)
|
||||
T = TreeMesh([4,4,4],levels=2)
|
||||
# T = TreeMesh([4,4,4],levels=2)
|
||||
# T = TreeMesh([[(1,128)],[(1,128)]],levels=7)
|
||||
# T.refine(lambda xc:2, balance=False)
|
||||
# T._index([0,0,0])
|
||||
@@ -2220,14 +2361,12 @@ if __name__ == '__main__':
|
||||
|
||||
|
||||
# tic = time.time()
|
||||
# T.refine(function)#, balance=False)
|
||||
T.refine(function)#, balance=False)
|
||||
# print time.time() - tic
|
||||
# print T.nC
|
||||
T.plotSlice(np.log(T.vol))#np.random.rand(T.nC))
|
||||
|
||||
print T.nC
|
||||
|
||||
P = T.getInterpolationMat([0.2,0,0], 'Ex')
|
||||
print P.todense()
|
||||
plt.show()
|
||||
blah
|
||||
|
||||
# T.plotImage(np.arange(len(T.vol)),showIt=True)
|
||||
|
||||
@@ -216,6 +216,7 @@ class TensorView(object):
|
||||
if ind is None: ind = int(szSliceDim/2)
|
||||
assert type(ind) in [int, long], 'ind must be an integer'
|
||||
|
||||
assert not (v.dtype == complex and view == 'vec'), 'Can not plot a complex vector.'
|
||||
# The slicing and plotting code!!
|
||||
|
||||
def getIndSlice(v):
|
||||
|
||||
Reference in New Issue
Block a user