Visualization.

This commit is contained in:
rowanc1
2013-12-20 11:44:20 -07:00
parent a01b805713
commit e2b345ae89
+44 -22
View File
@@ -1,5 +1,7 @@
from SimPEG import np, sp, utils
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cmx
@@ -33,7 +35,7 @@ class TreeFace(object):
self.mesh = mesh
self.children = None
self.numFace = None
self.x0 = np.array(x0,dtype=float)
self.x0 = np.array(x0, dtype=float)
self.faceType = faceType
self.sz = sz
self.dim = dim
@@ -95,9 +97,9 @@ class TreeNode(object):
self.mesh = mesh
self.x0 = np.array(x0, dtype=float)
self.sz = np.array(sz, dtype=float)
self.dim = dim
self.depth = depth
self.sz = np.array(sz, dtype=float)
self.parent = parent
self.children = None
self.numCell = None
@@ -181,19 +183,11 @@ class TreeNode(object):
def vol(self): return self.sz.prod()
def viz(self, ax, text=True):
if self.isleaf:
x0, sz = self.x0, self.sz
corners = np.c_[np.r_[x0[0] , x0[1] ],
np.r_[x0[0]+sz[0], x0[1] ],
np.r_[x0[0]+sz[0], x0[1]+sz[1]],
np.r_[x0[0] , x0[1]+sz[1]],
np.r_[x0[0] , x0[1] ]].T
ax.plot(corners[:,0],corners[:,1], 'b')
if self.numCell is not None and text:
ax.text(x0[0]+sz[0]/2,x0[1]+sz[1]/2,'%d'%self.numCell)
else:
[node.viz(ax) for node in self.children.flatten('F')]
def viz(self, ax, color='none', text=False):
if not self.isleaf: return
x0, sz = self.x0, self.sz
ax.add_patch(plt.Rectangle((x0[0], x0[1]), sz[0], sz[1], facecolor=color, edgecolor='k'))
if text: ax.text(self.gridCC[0],self.gridCC[1],self.numCell)
@@ -206,6 +200,7 @@ class QuadTreeMesh(object):
self.faceY = set()
self.cells = set()
self.x0 = np.array(x0,dtype=float)
self.h = h
self.children = np.empty([hi.size for hi in h],dtype=TreeNode)
for i in range(h[0].size):
for j in range(h[1].size):
@@ -242,11 +237,9 @@ class QuadTreeMesh(object):
self.isNumbered = True
def viz(self, ax=None, text=True):
if ax is None: ax = plt.subplot(111)
# [node.viz(ax, text=text) for node in self.cells]
[node.viz(ax, text=text) for node in self.faces]
@property
def dim(self): return len(self.x0)
@property
def nC(self): return len(self.cells)
@property
@@ -291,6 +284,28 @@ class QuadTreeMesh(object):
return self._faceDiv
def plotGrid(self, ax=None, text=True, plotC=True, plotF=False):
if ax is None: ax = plt.subplot(111)
if plotC: [node.viz(ax, text=text) for node in self.cells]
if plotF: [node.viz(ax, text=text) for node in self.faces]
ax.set_xlim((self.x0[0], self.h[0].sum()))
ax.set_ylim((self.x0[1], self.h[1].sum()))
def plotImage(self, I, ax=None):
if ax is None: ax = plt.subplot(111)
jet = cm = plt.get_cmap('jet')
cNorm = colors.Normalize(vmin=I.min(), vmax=I.max())
scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=jet)
ax.set_xlim((self.x0[0], self.h[0].sum()))
ax.set_ylim((self.x0[1], self.h[1].sum()))
for ii, node in enumerate(self.sortedCells):
node.viz(ax=ax, color=scalarMap.to_rgba(I[ii]))
scalarMap._A = [] # http://stackoverflow.com/questions/8342549/matplotlib-add-colorbar-to-a-sequence-of-line-plots
plt.colorbar(scalarMap)
if __name__ == '__main__':
M = QuadTreeMesh([np.ones(x) for x in [4,10]])
@@ -308,9 +323,16 @@ if __name__ == '__main__':
M.refine(function)
DIV = M.faceDiv
plt.subplot(211)
plt.spy(DIV)
# plt.subplot(211)
# plt.spy(DIV)
M.plotGrid(ax=plt.subplot(111),text=True)
M.viz(ax=plt.subplot(212),text=False)
q = np.zeros(M.nC)
q[208] = -1.0
q[291] = 1.0
b = utils.Solver(-DIV*DIV.T).solve(q)
plt.figure()
M.plotImage(b)
# plt.gca().invert_yaxis()
print M.vol
plt.show()