bug fixes.

This commit is contained in:
rowanc1
2014-02-10 17:35:55 -08:00
parent db73e4c018
commit 3bd4b0ee59
2 changed files with 53 additions and 17 deletions
+11 -6
View File
@@ -212,7 +212,7 @@ class TreeFace(TreeObject):
def index(self):
if not self.mesh.isNumbered: raise Exception('Mesh is not numbered.')
if self.isleaf: return np.r_[self.num]
return np.concatenate([face.index for face in self.children])
return np.concatenate([face.index for face in self.children.flatten(order='F')])
@property
def area(self):
@@ -590,7 +590,7 @@ class TreeCell(TreeObject):
j = self.faces[face].index
i = j*0+self.num
v = j*0+1
if 'p' in face:
if 'm' in face:
v *= -1
I, J, V = np.r_[I,i], np.r_[J,j], np.r_[V,v]
return I, J, V
@@ -671,7 +671,7 @@ class TreeMesh(object):
x0i = (np.r_[x0[0], h[0][:i]]).sum()
x0j = (np.r_[x0[1], h[1][:j]]).sum()
x0k = (np.r_[x0[2], h[2][:k]]).sum()
self.children[i][j] = TreeCell(self, x0=[x0i, x0j, x0k], depth=0, sz=[h[0][i], h[1][j], h[2][k]], fXm=fXm, fYm=fYm, fZm=fZm)
self.children[i][j][k] = TreeCell(self, x0=[x0i, x0j, x0k], depth=0, sz=[h[0][i], h[1][j], h[2][k]], fXm=fXm, fYm=fYm, fZm=fZm)
isNumbered = Utils.dependentProperty('_isNumbered', False, ['_faceDiv'], 'Setting this to False will delete all operators.')
@@ -833,19 +833,24 @@ class TreeMesh(object):
@property
def area(self):
self.number()
return np.concatenate(([face.area for face in self.sortedFaceX],[face.area for face in self.sortedFaceY]))
if self.dim == 2:
faces = self.sortedFaceX + self.sortedFaceY
elif self.dim == 3:
faces = self.sortedFaceX + self.sortedFaceY + self.sortedFaceZ
return np.array([face.area for face in faces], dtype=float)
@property
def faceDiv(self):
if getattr(self, '_faceDiv', None) is None:
self.number()
# TODO: Preallocate!
I, J, V = np.empty(0), np.empty(0), np.empty(0)
for cell in M.sortedCells:
for cell in self.sortedCells:
i, j, v = cell.faceIndex
I, J, V = np.r_[I,i], np.r_[J,j], np.r_[V,v]
VOL = self.vol
D = sp.csr_matrix((V,(I,J)), shape=(M.nC, M.nF))
D = sp.csr_matrix((V,(I,J)), shape=(self.nC, self.nF))
S = self.area
self._faceDiv = Utils.sdiag(1/VOL)*D*Utils.sdiag(S)
return self._faceDiv
+42 -11
View File
@@ -1,3 +1,4 @@
from SimPEG.Mesh import TensorMesh
from SimPEG.Mesh.TreeMesh import TreeMesh, TreeFace, TreeCell
import numpy as np
import unittest
@@ -13,6 +14,21 @@ class TestOcTreeObjects(unittest.TestCase):
self.Mr.children[0,0,0].refine()
self.Mr.number()
def q(s):
if s[0] == 'M':
m = self.M
s = s[1:]
else:
m = self.Mr
c = m.sortedCells[int(s[1])]
if len(s) == 2: return c
if s[2] == 'f' and len(s) == 5: return c.faces[s[2:]]
if s[2] == 'f': return c.faces[s[2:5]].edges[s[5:]]
if s[2] == 'e': return c.edges[s[2:]]
self.q = q
def test_counts(self):
self.assertTrue(self.M.nC == 2)
self.assertTrue(self.M.nFx == 3)
@@ -41,6 +57,18 @@ class TestOcTreeObjects(unittest.TestCase):
self.assertTrue(self.Mr.nEy == 20)
self.assertTrue(self.Mr.nEz == 20)
def test_sizes(self):
q = self.q
for key in ['Mc0','Mc1']:
self.assertTrue(q(key).vol == 0.5)
self.assertTrue(q(key+'fXm').area == 1.)
self.assertTrue(q(key+'fXp').area == 1.)
self.assertTrue(q(key+'fYm').area == 0.5)
self.assertTrue(q(key+'fYp').area == 0.5)
self.assertTrue(q(key+'fZm').area == 0.5)
self.assertTrue(q(key+'fZp').area == 0.5)
def test_pointersM(self):
c0 = self.M.children[0,0,0]
c0fXm = c0.faces['fXm']
@@ -118,17 +146,7 @@ class TestOcTreeObjects(unittest.TestCase):
cell.plotGrid(ax)
# plt.show()
def q(s):
if s[0] == 'M':
m = self.M
s = s[1:]
else:
m = self.Mr
c = m.sortedCells[int(s[1])]
if len(s) == 2: return c
if s[2] == 'f' and len(s) == 5: return c.faces[s[2:]]
if s[2] == 'f': return c.faces[s[2:5]].edges[s[5:]]
if s[2] == 'e': return c.edges[s[2:]]
q = self.q
c0 = self.Mr.sortedCells[0]
c0fXm = c0.faces['fXm']
@@ -465,6 +483,19 @@ class TestQuadTreeMesh(unittest.TestCase):
self.assertTrue(np.linalg.norm((np.c_[x,y]-self.M.gridEy).flatten()) == 0)
class SimpleOctreeOperatorTests(unittest.TestCase):
def setUp(self):
h1 = np.random.rand(5)
h2 = np.random.rand(7)
h3 = np.random.rand(3)
self.tM = TensorMesh([h1,h2,1])
self.oM = TreeMesh([h1,h2,1])
def test_faceDiv(self):
print (self.tM.faceDiv - self.oM.faceDiv)
self.assertTrue((self.tM.faceDiv - self.oM.faceDiv).toarray().sum() == 0)
if __name__ == '__main__':
unittest.main()