mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-01 04:00:05 +08:00
Visualization.
This commit is contained in:
+44
-22
@@ -1,5 +1,7 @@
|
||||
from SimPEG import np, sp, utils
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.colors as colors
|
||||
import matplotlib.cm as cmx
|
||||
|
||||
|
||||
|
||||
@@ -33,7 +35,7 @@ class TreeFace(object):
|
||||
self.mesh = mesh
|
||||
self.children = None
|
||||
self.numFace = None
|
||||
self.x0 = np.array(x0,dtype=float)
|
||||
self.x0 = np.array(x0, dtype=float)
|
||||
self.faceType = faceType
|
||||
self.sz = sz
|
||||
self.dim = dim
|
||||
@@ -95,9 +97,9 @@ class TreeNode(object):
|
||||
|
||||
self.mesh = mesh
|
||||
self.x0 = np.array(x0, dtype=float)
|
||||
self.sz = np.array(sz, dtype=float)
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
self.sz = np.array(sz, dtype=float)
|
||||
self.parent = parent
|
||||
self.children = None
|
||||
self.numCell = None
|
||||
@@ -181,19 +183,11 @@ class TreeNode(object):
|
||||
def vol(self): return self.sz.prod()
|
||||
|
||||
|
||||
def viz(self, ax, text=True):
|
||||
if self.isleaf:
|
||||
x0, sz = self.x0, self.sz
|
||||
corners = np.c_[np.r_[x0[0] , x0[1] ],
|
||||
np.r_[x0[0]+sz[0], x0[1] ],
|
||||
np.r_[x0[0]+sz[0], x0[1]+sz[1]],
|
||||
np.r_[x0[0] , x0[1]+sz[1]],
|
||||
np.r_[x0[0] , x0[1] ]].T
|
||||
ax.plot(corners[:,0],corners[:,1], 'b')
|
||||
if self.numCell is not None and text:
|
||||
ax.text(x0[0]+sz[0]/2,x0[1]+sz[1]/2,'%d'%self.numCell)
|
||||
else:
|
||||
[node.viz(ax) for node in self.children.flatten('F')]
|
||||
def viz(self, ax, color='none', text=False):
|
||||
if not self.isleaf: return
|
||||
x0, sz = self.x0, self.sz
|
||||
ax.add_patch(plt.Rectangle((x0[0], x0[1]), sz[0], sz[1], facecolor=color, edgecolor='k'))
|
||||
if text: ax.text(self.gridCC[0],self.gridCC[1],self.numCell)
|
||||
|
||||
|
||||
|
||||
@@ -206,6 +200,7 @@ class QuadTreeMesh(object):
|
||||
self.faceY = set()
|
||||
self.cells = set()
|
||||
self.x0 = np.array(x0,dtype=float)
|
||||
self.h = h
|
||||
self.children = np.empty([hi.size for hi in h],dtype=TreeNode)
|
||||
for i in range(h[0].size):
|
||||
for j in range(h[1].size):
|
||||
@@ -242,11 +237,9 @@ class QuadTreeMesh(object):
|
||||
|
||||
self.isNumbered = True
|
||||
|
||||
def viz(self, ax=None, text=True):
|
||||
if ax is None: ax = plt.subplot(111)
|
||||
# [node.viz(ax, text=text) for node in self.cells]
|
||||
[node.viz(ax, text=text) for node in self.faces]
|
||||
|
||||
@property
|
||||
def dim(self): return len(self.x0)
|
||||
@property
|
||||
def nC(self): return len(self.cells)
|
||||
@property
|
||||
@@ -291,6 +284,28 @@ class QuadTreeMesh(object):
|
||||
return self._faceDiv
|
||||
|
||||
|
||||
def plotGrid(self, ax=None, text=True, plotC=True, plotF=False):
|
||||
if ax is None: ax = plt.subplot(111)
|
||||
|
||||
if plotC: [node.viz(ax, text=text) for node in self.cells]
|
||||
if plotF: [node.viz(ax, text=text) for node in self.faces]
|
||||
ax.set_xlim((self.x0[0], self.h[0].sum()))
|
||||
ax.set_ylim((self.x0[1], self.h[1].sum()))
|
||||
|
||||
|
||||
def plotImage(self, I, ax=None):
|
||||
if ax is None: ax = plt.subplot(111)
|
||||
jet = cm = plt.get_cmap('jet')
|
||||
cNorm = colors.Normalize(vmin=I.min(), vmax=I.max())
|
||||
scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=jet)
|
||||
ax.set_xlim((self.x0[0], self.h[0].sum()))
|
||||
ax.set_ylim((self.x0[1], self.h[1].sum()))
|
||||
for ii, node in enumerate(self.sortedCells):
|
||||
node.viz(ax=ax, color=scalarMap.to_rgba(I[ii]))
|
||||
scalarMap._A = [] # http://stackoverflow.com/questions/8342549/matplotlib-add-colorbar-to-a-sequence-of-line-plots
|
||||
plt.colorbar(scalarMap)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
M = QuadTreeMesh([np.ones(x) for x in [4,10]])
|
||||
@@ -308,9 +323,16 @@ if __name__ == '__main__':
|
||||
M.refine(function)
|
||||
|
||||
DIV = M.faceDiv
|
||||
plt.subplot(211)
|
||||
plt.spy(DIV)
|
||||
# plt.subplot(211)
|
||||
# plt.spy(DIV)
|
||||
M.plotGrid(ax=plt.subplot(111),text=True)
|
||||
|
||||
M.viz(ax=plt.subplot(212),text=False)
|
||||
q = np.zeros(M.nC)
|
||||
q[208] = -1.0
|
||||
q[291] = 1.0
|
||||
b = utils.Solver(-DIV*DIV.T).solve(q)
|
||||
plt.figure()
|
||||
M.plotImage(b)
|
||||
# plt.gca().invert_yaxis()
|
||||
print M.vol
|
||||
plt.show()
|
||||
|
||||
Reference in New Issue
Block a user