diff --git a/SimPEG/mesh/QuadTreeMesh.py b/SimPEG/mesh/QuadTreeMesh.py index 6fe6189d..1ac0f5eb 100644 --- a/SimPEG/mesh/QuadTreeMesh.py +++ b/SimPEG/mesh/QuadTreeMesh.py @@ -1,5 +1,7 @@ from SimPEG import np, sp, utils import matplotlib.pyplot as plt +import matplotlib.colors as colors +import matplotlib.cm as cmx @@ -33,7 +35,7 @@ class TreeFace(object): self.mesh = mesh self.children = None self.numFace = None - self.x0 = np.array(x0,dtype=float) + self.x0 = np.array(x0, dtype=float) self.faceType = faceType self.sz = sz self.dim = dim @@ -95,9 +97,9 @@ class TreeNode(object): self.mesh = mesh self.x0 = np.array(x0, dtype=float) + self.sz = np.array(sz, dtype=float) self.dim = dim self.depth = depth - self.sz = np.array(sz, dtype=float) self.parent = parent self.children = None self.numCell = None @@ -181,19 +183,11 @@ class TreeNode(object): def vol(self): return self.sz.prod() - def viz(self, ax, text=True): - if self.isleaf: - x0, sz = self.x0, self.sz - corners = np.c_[np.r_[x0[0] , x0[1] ], - np.r_[x0[0]+sz[0], x0[1] ], - np.r_[x0[0]+sz[0], x0[1]+sz[1]], - np.r_[x0[0] , x0[1]+sz[1]], - np.r_[x0[0] , x0[1] ]].T - ax.plot(corners[:,0],corners[:,1], 'b') - if self.numCell is not None and text: - ax.text(x0[0]+sz[0]/2,x0[1]+sz[1]/2,'%d'%self.numCell) - else: - [node.viz(ax) for node in self.children.flatten('F')] + def viz(self, ax, color='none', text=False): + if not self.isleaf: return + x0, sz = self.x0, self.sz + ax.add_patch(plt.Rectangle((x0[0], x0[1]), sz[0], sz[1], facecolor=color, edgecolor='k')) + if text: ax.text(self.gridCC[0],self.gridCC[1],self.numCell) @@ -206,6 +200,7 @@ class QuadTreeMesh(object): self.faceY = set() self.cells = set() self.x0 = np.array(x0,dtype=float) + self.h = h self.children = np.empty([hi.size for hi in h],dtype=TreeNode) for i in range(h[0].size): for j in range(h[1].size): @@ -242,11 +237,9 @@ class QuadTreeMesh(object): self.isNumbered = True - def viz(self, ax=None, text=True): - if ax is None: ax = plt.subplot(111) - # [node.viz(ax, text=text) for node in self.cells] - [node.viz(ax, text=text) for node in self.faces] + @property + def dim(self): return len(self.x0) @property def nC(self): return len(self.cells) @property @@ -291,6 +284,28 @@ class QuadTreeMesh(object): return self._faceDiv + def plotGrid(self, ax=None, text=True, plotC=True, plotF=False): + if ax is None: ax = plt.subplot(111) + + if plotC: [node.viz(ax, text=text) for node in self.cells] + if plotF: [node.viz(ax, text=text) for node in self.faces] + ax.set_xlim((self.x0[0], self.h[0].sum())) + ax.set_ylim((self.x0[1], self.h[1].sum())) + + + def plotImage(self, I, ax=None): + if ax is None: ax = plt.subplot(111) + jet = cm = plt.get_cmap('jet') + cNorm = colors.Normalize(vmin=I.min(), vmax=I.max()) + scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=jet) + ax.set_xlim((self.x0[0], self.h[0].sum())) + ax.set_ylim((self.x0[1], self.h[1].sum())) + for ii, node in enumerate(self.sortedCells): + node.viz(ax=ax, color=scalarMap.to_rgba(I[ii])) + scalarMap._A = [] # http://stackoverflow.com/questions/8342549/matplotlib-add-colorbar-to-a-sequence-of-line-plots + plt.colorbar(scalarMap) + + if __name__ == '__main__': M = QuadTreeMesh([np.ones(x) for x in [4,10]]) @@ -308,9 +323,16 @@ if __name__ == '__main__': M.refine(function) DIV = M.faceDiv - plt.subplot(211) - plt.spy(DIV) + # plt.subplot(211) + # plt.spy(DIV) + M.plotGrid(ax=plt.subplot(111),text=True) - M.viz(ax=plt.subplot(212),text=False) + q = np.zeros(M.nC) + q[208] = -1.0 + q[291] = 1.0 + b = utils.Solver(-DIV*DIV.T).solve(q) + plt.figure() + M.plotImage(b) # plt.gca().invert_yaxis() + print M.vol plt.show()