mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-27 18:25:42 +08:00
Updates to inheritance and faster innerproducts on treemesh.
This commit is contained in:
@@ -2,12 +2,12 @@ import numpy as np
|
||||
import scipy.sparse as sp
|
||||
from scipy.constants import pi
|
||||
from SimPEG.Utils import mkvc, ndgrid, sdiag, kron3, speye, spzeros, ddx, av, avExtrap
|
||||
from TensorMesh import BaseTensorMesh
|
||||
from TensorMesh import BaseTensorMesh, BaseRectangularMesh
|
||||
from InnerProducts import InnerProducts
|
||||
from View import CylView
|
||||
|
||||
|
||||
class CylMesh(BaseTensorMesh, InnerProducts, CylView):
|
||||
class CylMesh(BaseTensorMesh, BaseRectangularMesh, InnerProducts, CylView):
|
||||
"""
|
||||
CylMesh is a mesh class for cylindrical problems
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from SimPEG import Utils, np, sp
|
||||
from BaseMesh import BaseRectangularMesh
|
||||
from BaseMesh import BaseMesh, BaseRectangularMesh
|
||||
from View import TensorView
|
||||
from DiffOperators import DiffOperators
|
||||
from InnerProducts import InnerProducts
|
||||
|
||||
class BaseTensorMesh(BaseRectangularMesh):
|
||||
class BaseTensorMesh(BaseMesh):
|
||||
|
||||
__metaclass__ = Utils.SimPEGMetaClass
|
||||
|
||||
@@ -42,7 +42,10 @@ class BaseTensorMesh(BaseRectangularMesh):
|
||||
else:
|
||||
raise Exception("x0[%i] must be a scalar or '0' to be zero, 'C' to center, or 'N' to be negative." % i)
|
||||
|
||||
BaseRectangularMesh.__init__(self, np.array([x.size for x in h]), x0)
|
||||
if isinstance(self, BaseRectangularMesh):
|
||||
BaseRectangularMesh.__init__(self, np.array([x.size for x in h]), x0)
|
||||
else:
|
||||
BaseMesh.__init__(self, np.array([x.size for x in h]), x0)
|
||||
|
||||
# Ensure h contains 1D vectors
|
||||
self._h = [Utils.mkvc(x.astype(float)) for x in h]
|
||||
@@ -356,7 +359,7 @@ class BaseTensorMesh(BaseRectangularMesh):
|
||||
|
||||
|
||||
|
||||
class TensorMesh(BaseTensorMesh, TensorView, DiffOperators, InnerProducts):
|
||||
class TensorMesh(BaseTensorMesh, BaseRectangularMesh, TensorView, DiffOperators, InnerProducts):
|
||||
"""
|
||||
TensorMesh is a mesh class that deals with tensor product meshes.
|
||||
|
||||
|
||||
+8
-149
@@ -98,51 +98,23 @@ import matplotlib.cm as cmx
|
||||
|
||||
import TreeUtils
|
||||
from InnerProducts import InnerProducts
|
||||
from BaseMesh import BaseMesh
|
||||
from TensorMesh import TensorMesh
|
||||
from TensorMesh import TensorMesh, BaseTensorMesh
|
||||
import time
|
||||
|
||||
MAX_BITS = 20
|
||||
|
||||
class TreeMesh(BaseMesh, InnerProducts):
|
||||
class TreeMesh(BaseTensorMesh, InnerProducts):
|
||||
|
||||
_meshType = 'TREE'
|
||||
|
||||
def __init__(self, h_in, x0_in=None, levels=None):
|
||||
assert type(h_in) is list, 'h_in must be a list'
|
||||
assert len(h_in) in [2,3], "There is only support for TreeMesh in 2D or 3D."
|
||||
def __init__(self, h, x0=None, levels=None):
|
||||
assert type(h) is list, 'h must be a list'
|
||||
assert len(h) in [2,3], "There is only support for TreeMesh in 2D or 3D."
|
||||
|
||||
h = range(len(h_in))
|
||||
for i, h_i in enumerate(h_in):
|
||||
if type(h_i) in [int, long, float]:
|
||||
# This gives you something over the unit cube.
|
||||
h_i = np.ones(int(h_i))/int(h_i)
|
||||
elif type(h_i) is list:
|
||||
h_i = Utils.meshTensor(h_i)
|
||||
assert isinstance(h_i, np.ndarray), ("h[%i] is not a numpy array." % i)
|
||||
assert len(h_i.shape) == 1, ("h[%i] must be a 1D numpy array." % i)
|
||||
if levels is None:levels = int(np.log2(len(h_i)))
|
||||
assert len(h_i) == 2**levels, "must make h and levels match"
|
||||
h[i] = h_i[:] # make a copy.
|
||||
self._h = h
|
||||
BaseTensorMesh.__init__(self, h, x0)
|
||||
|
||||
x0 = np.zeros(len(h))
|
||||
if x0_in is not None:
|
||||
assert len(h) == len(x0_in), "Dimension mismatch. x0 != len(h)"
|
||||
for i in range(len(h)):
|
||||
x_i, h_i = x0_in[i], h[i]
|
||||
if Utils.isScalar(x_i):
|
||||
x0[i] = x_i
|
||||
elif x_i == '0':
|
||||
x0[i] = 0.0
|
||||
elif x_i == 'C':
|
||||
x0[i] = -h_i.sum()*0.5
|
||||
elif x_i == 'N':
|
||||
x0[i] = -h_i.sum()
|
||||
else:
|
||||
raise Exception("x0[%i] must be a scalar or '0' to be zero, 'C' to center, or 'N' to be negative." % i)
|
||||
|
||||
BaseMesh.__init__(self, [len(_) for _ in h], x0)
|
||||
if levels is None:levels = int(np.log2(len(self._h[0])))
|
||||
assert np.all(len(_) == 2**levels for _ in self._h), "must make h and levels match"
|
||||
|
||||
self._levels = levels
|
||||
self._levelBits = int(np.ceil(np.sqrt(levels)))+1
|
||||
@@ -238,95 +210,6 @@ class TreeMesh(BaseMesh, InnerProducts):
|
||||
outStr += '\n Fill: %2.2f%%'%(self.fill*100)
|
||||
return outStr
|
||||
|
||||
@property
|
||||
def h(self):
|
||||
"""h is a list containing the cell widths of the tensor mesh in each dimension."""
|
||||
return self._h
|
||||
|
||||
@property
|
||||
def hx(self):
|
||||
"Width of cells in the x direction"
|
||||
return self._h[0]
|
||||
|
||||
@property
|
||||
def hy(self):
|
||||
"Width of cells in the y direction"
|
||||
return self._h[1]
|
||||
|
||||
@property
|
||||
def hz(self):
|
||||
"Width of cells in the z direction"
|
||||
return None if self.dim < 3 else self._h[2]
|
||||
|
||||
@property
|
||||
def vectorNx(self):
|
||||
"""Nodal grid vector (1D) in the x direction."""
|
||||
return np.r_[0., self.hx.cumsum()] + self.x0[0]
|
||||
|
||||
@property
|
||||
def vectorNy(self):
|
||||
"""Nodal grid vector (1D) in the y direction."""
|
||||
return np.r_[0., self.hy.cumsum()] + self.x0[1]
|
||||
|
||||
@property
|
||||
def vectorNz(self):
|
||||
"""Nodal grid vector (1D) in the z direction."""
|
||||
return None if self.dim < 3 else np.r_[0., self.hz.cumsum()] + self.x0[2]
|
||||
|
||||
@property
|
||||
def vectorCCx(self):
|
||||
"""Cell-centered grid vector (1D) in the x direction."""
|
||||
return np.r_[0, self.hx[:-1].cumsum()] + self.hx*0.5 + self.x0[0]
|
||||
|
||||
@property
|
||||
def vectorCCy(self):
|
||||
"""Cell-centered grid vector (1D) in the y direction."""
|
||||
return np.r_[0, self.hy[:-1].cumsum()] + self.hy*0.5 + self.x0[1]
|
||||
|
||||
@property
|
||||
def vectorCCz(self):
|
||||
"""Cell-centered grid vector (1D) in the z direction."""
|
||||
return None if self.dim < 3 else np.r_[0, self.hz[:-1].cumsum()] + self.hz*0.5 + self.x0[2]
|
||||
|
||||
def getTensor(self, key):
|
||||
""" Returns a tensor list.
|
||||
|
||||
:param str key: What tensor (see below)
|
||||
:rtype: list
|
||||
:return: list of the tensors that make up the mesh.
|
||||
|
||||
key can be::
|
||||
|
||||
'CC' -> scalar field defined on cell centers
|
||||
'N' -> scalar field defined on nodes
|
||||
'Fx' -> x-component of field defined on faces
|
||||
'Fy' -> y-component of field defined on faces
|
||||
'Fz' -> z-component of field defined on faces
|
||||
'Ex' -> x-component of field defined on edges
|
||||
'Ey' -> y-component of field defined on edges
|
||||
'Ez' -> z-component of field defined on edges
|
||||
|
||||
"""
|
||||
|
||||
if key == 'Fx':
|
||||
ten = [self.vectorNx , self.vectorCCy, self.vectorCCz]
|
||||
elif key == 'Fy':
|
||||
ten = [self.vectorCCx, self.vectorNy , self.vectorCCz]
|
||||
elif key == 'Fz':
|
||||
ten = [self.vectorCCx, self.vectorCCy, self.vectorNz ]
|
||||
elif key == 'Ex':
|
||||
ten = [self.vectorCCx, self.vectorNy , self.vectorNz ]
|
||||
elif key == 'Ey':
|
||||
ten = [self.vectorNx , self.vectorCCy, self.vectorNz ]
|
||||
elif key == 'Ez':
|
||||
ten = [self.vectorNx , self.vectorNy , self.vectorCCz]
|
||||
elif key == 'CC':
|
||||
ten = [self.vectorCCx, self.vectorCCy, self.vectorCCz]
|
||||
elif key == 'N':
|
||||
ten = [self.vectorNx , self.vectorNy , self.vectorNz ]
|
||||
|
||||
return [t for t in ten if t is not None]
|
||||
|
||||
@property
|
||||
def nC(self): return len(self._cells)
|
||||
|
||||
@@ -1892,30 +1775,6 @@ class TreeMesh(BaseMesh, InnerProducts):
|
||||
return self._getEdgeP(xEdge, yEdge, zEdge)
|
||||
return Pxxx
|
||||
|
||||
|
||||
def isInside(self, pts, locType='N'):
|
||||
"""
|
||||
Determines if a set of points are inside a mesh.
|
||||
|
||||
:param numpy.ndarray pts: Location of points to test
|
||||
:rtype numpy.ndarray
|
||||
:return inside, numpy array of booleans
|
||||
"""
|
||||
pts = Utils.asArray_N_x_Dim(pts, self.dim)
|
||||
|
||||
tensors = self.getTensor(locType)
|
||||
|
||||
if locType == 'N' and self._meshType == 'CYL':
|
||||
#NOTE: for a CYL mesh we add a node to check if we are inside in the radial direction!
|
||||
tensors[0] = np.r_[0.,tensors[0]]
|
||||
tensors[1] = np.r_[tensors[1], 2.0*np.pi]
|
||||
|
||||
inside = np.ones(pts.shape[0],dtype=bool)
|
||||
for i, tensor in enumerate(tensors):
|
||||
TOL = np.diff(tensor).min() * 1.0e-10
|
||||
inside = inside & (pts[:,i] >= tensor.min()-TOL) & (pts[:,i] <= tensor.max()+TOL)
|
||||
return inside
|
||||
|
||||
def point2index(self, locs):
|
||||
locs = Utils.asArray_N_x_Dim(locs, self.dim)
|
||||
|
||||
|
||||
@@ -105,7 +105,7 @@ class TestCurl(Tests.OrderTest):
|
||||
class TestTreeInnerProducts(Tests.OrderTest):
|
||||
"""Integrate an function over a unit cube domain using edgeInnerProducts and faceInnerProducts."""
|
||||
|
||||
meshTypes = ['uniformTree'] #['uniformTensorMesh', 'uniformCurv', 'rotateCurv']
|
||||
meshTypes = ['uniformTree', 'notatreeTree'] #['uniformTensorMesh', 'uniformCurv', 'rotateCurv']
|
||||
meshDimension = 3
|
||||
meshSizes = [4, 8]
|
||||
|
||||
|
||||
@@ -218,18 +218,18 @@ class TestInnerProductsDerivs(unittest.TestCase):
|
||||
def test_FaceIP_3D_tensor_Tree(self):
|
||||
self.assertTrue(self.doTestFace([8, 8, 8],6, False, 'Tree'))
|
||||
|
||||
# def test_FaceIP_2D_float_fast_Tree(self):
|
||||
# self.assertTrue(self.doTestFace([8, 8],0, True, 'Tree'))
|
||||
# def test_FaceIP_3D_float_fast_Tree(self):
|
||||
# self.assertTrue(self.doTestFace([8, 8, 8],0, True, 'Tree'))
|
||||
# def test_FaceIP_2D_isotropic_fast_Tree(self):
|
||||
# self.assertTrue(self.doTestFace([8, 8],1, True, 'Tree'))
|
||||
# def test_FaceIP_3D_isotropic_fast_Tree(self):
|
||||
# self.assertTrue(self.doTestFace([8, 8, 8],1, True, 'Tree'))
|
||||
# def test_FaceIP_2D_anisotropic_fast_Tree(self):
|
||||
# self.assertTrue(self.doTestFace([8, 8],2, True, 'Tree'))
|
||||
# def test_FaceIP_3D_anisotropic_fast_Tree(self):
|
||||
# self.assertTrue(self.doTestFace([8, 8, 8],3, True, 'Tree'))
|
||||
def test_FaceIP_2D_float_fast_Tree(self):
|
||||
self.assertTrue(self.doTestFace([8, 8],0, True, 'Tree'))
|
||||
def test_FaceIP_3D_float_fast_Tree(self):
|
||||
self.assertTrue(self.doTestFace([8, 8, 8],0, True, 'Tree'))
|
||||
def test_FaceIP_2D_isotropic_fast_Tree(self):
|
||||
self.assertTrue(self.doTestFace([8, 8],1, True, 'Tree'))
|
||||
def test_FaceIP_3D_isotropic_fast_Tree(self):
|
||||
self.assertTrue(self.doTestFace([8, 8, 8],1, True, 'Tree'))
|
||||
def test_FaceIP_2D_anisotropic_fast_Tree(self):
|
||||
self.assertTrue(self.doTestFace([8, 8],2, True, 'Tree'))
|
||||
def test_FaceIP_3D_anisotropic_fast_Tree(self):
|
||||
self.assertTrue(self.doTestFace([8, 8, 8],3, True, 'Tree'))
|
||||
|
||||
# def test_EdgeIP_2D_float_Tree(self):
|
||||
# self.assertTrue(self.doTestEdge([8, 8],0, False, 'Tree'))
|
||||
@@ -250,16 +250,16 @@ class TestInnerProductsDerivs(unittest.TestCase):
|
||||
|
||||
# def test_EdgeIP_2D_float_fast_Tree(self):
|
||||
# self.assertTrue(self.doTestEdge([8, 8],0, True, 'Tree'))
|
||||
# def test_EdgeIP_3D_float_fast_Tree(self):
|
||||
# self.assertTrue(self.doTestEdge([8, 8, 8],0, True, 'Tree'))
|
||||
def test_EdgeIP_3D_float_fast_Tree(self):
|
||||
self.assertTrue(self.doTestEdge([8, 8, 8],0, True, 'Tree'))
|
||||
# def test_EdgeIP_2D_isotropic_fast_Tree(self):
|
||||
# self.assertTrue(self.doTestEdge([8, 8],1, True, 'Tree'))
|
||||
# def test_EdgeIP_3D_isotropic_fast_Tree(self):
|
||||
# self.assertTrue(self.doTestEdge([8, 8, 8],1, True, 'Tree'))
|
||||
def test_EdgeIP_3D_isotropic_fast_Tree(self):
|
||||
self.assertTrue(self.doTestEdge([8, 8, 8],1, True, 'Tree'))
|
||||
# def test_EdgeIP_2D_anisotropic_fast_Tree(self):
|
||||
# self.assertTrue(self.doTestEdge([8, 8],2, True, 'Tree'))
|
||||
# def test_EdgeIP_3D_anisotropic_fast_Tree(self):
|
||||
# self.assertTrue(self.doTestEdge([8, 8, 8],3, True, 'Tree'))
|
||||
def test_EdgeIP_3D_anisotropic_fast_Tree(self):
|
||||
self.assertTrue(self.doTestEdge([8, 8, 8],3, True, 'Tree'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user