diff --git a/SimPEG/Mesh/CylMesh.py b/SimPEG/Mesh/CylMesh.py index 3214a9ed..93cd5451 100644 --- a/SimPEG/Mesh/CylMesh.py +++ b/SimPEG/Mesh/CylMesh.py @@ -4,9 +4,10 @@ from scipy.constants import pi from SimPEG.Utils import mkvc, ndgrid, sdiag, kron3, speye, ddx, av, avExtrap from TensorMesh import BaseTensorMesh from InnerProducts import InnerProducts +from View import CylView -class CylMesh(BaseTensorMesh, InnerProducts): +class CylMesh(BaseTensorMesh, InnerProducts, CylView): """ CylMesh is a mesh class for cylindrical problems @@ -293,3 +294,14 @@ class CylMesh(BaseTensorMesh, InnerProducts): # kron3(speye(n[2]), av(n[1]), speye(n[0])), # kron3(av(n[2]), speye(n[1]), speye(n[0]))), format="csr") return self._aveF2CC + + +if __name__ == '__main__': + + from SimPEG import * + hx = np.r_[1,1,0.5] + hz = np.r_[2,1] + M = Mesh.CylMesh([hx, 1,hz], x0='00N') + + M.plotImage(np.random.rand(M.nC), showIt=False) + M.plotGrid(centers=True, showIt=True) diff --git a/SimPEG/Mesh/TensorMesh.py b/SimPEG/Mesh/TensorMesh.py index d9a14aaa..cd46fb85 100644 --- a/SimPEG/Mesh/TensorMesh.py +++ b/SimPEG/Mesh/TensorMesh.py @@ -13,7 +13,7 @@ class BaseTensorMesh(BaseRectangularMesh): _unitDimensions = [1, 1, 1] def __init__(self, h_in, x0_in=None): - assert type(h_in) is list, 'h_in must be a list' + assert type(h_in) in [list, tuple], 'h_in must be a list' assert len(h_in) in [1,2,3], 'h_in must be of dimension 1, 2, or 3' h = range(len(h_in)) for i, h_i in enumerate(h_in): diff --git a/SimPEG/Mesh/View.py b/SimPEG/Mesh/View.py index a4753b5e..6663b44f 100644 --- a/SimPEG/Mesh/View.py +++ b/SimPEG/Mesh/View.py @@ -518,7 +518,46 @@ class TensorView(object): return animate(fig, animateFrame, frames=len(frames)) +class CylView(object): + def _plotCylTensorMesh(self, plotType, *args, **kwargs): + + if not self.isSymmetric: + raise Exception('We have not yet implemented this type of view.') + assert plotType in ['plotImage', 'plotGrid'] + # Hackity Hack: + # Just create a TM and use its view. + from SimPEG.Mesh import TensorMesh + M = TensorMesh([self.hx, self.hz], x0=[self.x0[0], self.x0[2]]) + + ax = kwargs.get('ax', None) + if ax is None: + fig = plt.figure() + ax = plt.subplot(111) + kwargs['ax'] = ax + else: + assert isinstance(ax, matplotlib.axes.Axes), "ax must be an matplotlib.axes.Axes" + fig = ax.figure + + # Don't show things in the TM.plotImage + showIt = kwargs.get('showIt', False) + kwargs['showIt'] = False + + out = getattr(M, plotType)(*args, **kwargs) + + ax.set_xlabel('x') + ax.set_ylabel('z') + + if showIt: plt.show() + + return out + + + def plotGrid(self, *args, **kwargs): + return self._plotCylTensorMesh('plotGrid', *args, **kwargs) + + def plotImage(self, *args, **kwargs): + return self._plotCylTensorMesh('plotImage', *args, **kwargs) class LomView(object): """