Corsen Trees

This commit is contained in:
Rowan Cockett
2015-11-18 17:15:26 -08:00
parent ca05d0599e
commit dcb9b8787d
2 changed files with 132 additions and 44 deletions
+72 -42
View File
@@ -411,6 +411,10 @@ class TreeMesh(BaseTensorMesh, InnerProducts):
def refine(self, function=None, recursive=True, cells=None, balance=True, verbose=False, _inRecursion=False):
if type(function) in [int, long]:
level = function
function = lambda cell: level
if not _inRecursion:
self.__dirty__ = True
if verbose: print 'Refining Mesh'
@@ -421,65 +425,86 @@ class TreeMesh(BaseTensorMesh, InnerProducts):
for cell in cells:
p = self._pointer(cell)
if p[-1] >= self.levels: continue
do = function(Cell(self, cell, p)) > p[-1]
result = function(Cell(self, cell, p))
if type(result) is bool:
do = result
elif type(result) in [int,long]:
do = result > p[-1]
else:
raise Exception('You must tell the program what to refine. Use BOOL or INT (level)')
if do:
recurse += self._refineCell(cell, p)
if verbose: print ' ', time.time() - tic
if recursive and len(recurse) > 0:
recurse += self.refine(function=function, recursive=True, cells=recurse, balance=balance, _inRecursion=True)
recurse += self.refine(function=function, recursive=True, cells=recurse, balance=balance, verbose=verbose, _inRecursion=True)
if balance and not _inRecursion:
self.balance()
return recurse
def corsen(self, function=None, recursive=True, cells=None, balance=True, verbose=False, _inRecursion=False):
if type(function) in [int, long]:
level = function
function = lambda cell: level
if not _inRecursion:
self.__dirty__ = True
if verbose: print 'Corsening Mesh'
cells = cells if cells is not None else sorted(self._cells)
recurse = []
tic = time.time()
for cell in cells:
if cell not in self._cells: continue # already removed
p = self._pointer(cell)
if p[-1] >= self.levels: continue
result = function(Cell(self, cell, p))
if type(result) is bool:
do = result
elif type(result) in [int,long]:
do = result < p[-1]
else:
raise Exception('You must tell the program what to corsen. Use BOOL or INT (level)')
if do:
recurse += self._corsenCell(cell, p)
if verbose: print ' ', time.time() - tic
if recursive and len(recurse) > 0:
recurse += self.corsen(function=function, recursive=True, cells=recurse, balance=balance, verbose=verbose, _inRecursion=True)
if balance and not _inRecursion:
self.balance()
return recurse
if verbose: print ' ', time.time() - tic
def _refineCell(self, ind, pointer=None):
ind = self._asIndex(ind)
pointer = self._asPointer(pointer if pointer is not None else ind)
assert ind in self
h = self._levelWidth(pointer[-1])/2 # halfWidth
nL = pointer[-1] + 1 # new level
add = lambda p:p[0]+p[1]
added = []
def addCell(p):
i = self._index(p+[nL])
self._cells.add(i)
added.append(i)
addCell(map(add, zip(pointer[:-1], [0,0,0][:self.dim])))
addCell(map(add, zip(pointer[:-1], [h,0,0][:self.dim])))
addCell(map(add, zip(pointer[:-1], [0,h,0][:self.dim])))
addCell(map(add, zip(pointer[:-1], [h,h,0][:self.dim])))
if self.dim == 3:
addCell(map(add, zip(pointer[:-1], [0,0,h])))
addCell(map(add, zip(pointer[:-1], [h,0,h])))
addCell(map(add, zip(pointer[:-1], [0,h,h])))
addCell(map(add, zip(pointer[:-1], [h,h,h])))
if ind not in self:
raise CellLookUpException(ind)
children = self._childPointers(pointer, returnAll=True)
for child in children:
self._cells.add(self._asIndex(child))
self._cells.remove(ind)
return added
return [self._asIndex(child) for child in children]
def corsen(self, function=None):
self.__dirty__ = True
raise Exception('Not yet implemented')
def _corsenCell(self, pointer):
raise Exception('Not yet implemented')
# something like this: ??
pointer = self._asPointer(pointer)
ind = self._asIndex(pointer)
assert ind in self
parent = self._parentPointer(ind)
children = _childPointers(parent)
def _corsenCell(self, ind, pointer=None):
ind = self._asIndex(ind)
pointer = self._asPointer(pointer if pointer is not None else ind)
if ind not in self:
raise CellLookUpException(ind)
parent = self._parentPointer(pointer)
children = self._childPointers(parent, returnAll=True)
for child in children:
self._cells.remove(self._asIndex(child))
parentInd = self._asIndex(parent)
self._cells.add(parentInd)
return parentInd
return [parentInd]
def _asPointer(self, ind):
if type(ind) in [int, long]:
@@ -2225,7 +2250,12 @@ def SortGrid(grid, offset=0):
return sorted(range(offset,grid.shape[0]+offset), key=K)
class NotBalancedException(Exception):
class TreeException(Exception):
pass
class NotBalancedException(TreeException):
pass
class CellLookUpException(TreeException):
pass
if __name__ == '__main__':
@@ -2250,9 +2280,9 @@ if __name__ == '__main__':
return 2
# T = TreeMesh([[(1,128)],[(1,128)],[(1,128)]],levels=7)
T = TreeMesh([128,128,128])
# T = TreeMesh([128,128,128])
# T = TreeMesh([64,64],levels=6)
# T = TreeMesh([4,4,4],levels=2)
T = TreeMesh([4,4,4])
# T = TreeMesh([[(1,128)],[(1,128)]],levels=7)
# T.refine(lambda xc:2, balance=False)
# T._index([0,0,0])
+60 -2
View File
@@ -1,12 +1,11 @@
from SimPEG import Mesh, Tests
from SimPEG.Mesh.TreeMesh import CellLookUpException
import numpy as np
import matplotlib.pyplot as plt
import unittest
TOL = 1e-8
class TestSimpleQuadTree(unittest.TestCase):
def test_counts(self):
@@ -19,6 +18,7 @@ class TestSimpleQuadTree(unittest.TestCase):
M._refineCell([0,0,1])
M.number()
# M.plotGrid(showIt=True)
print M
assert M.nhFx == 2
assert M.nFx == 9
@@ -26,6 +26,64 @@ class TestSimpleQuadTree(unittest.TestCase):
assert np.allclose(np.r_[M._areaFxFull, M._areaFyFull], M._deflationMatrix('F') * M.area)
def test_refine(self):
M = Mesh.TreeMesh([4,4,4])
M.refine(1)
assert M.nC == 8
M.refine(0)
assert M.nC == 8
M.corsen(0)
assert M.nC == 1
def test_corsen(self):
nc = 8
h1 = np.random.rand(nc)*nc*0.5 + nc*0.5
h2 = np.random.rand(nc)*nc*0.5 + nc*0.5
h = [hi/np.sum(hi) for hi in [h1, h2]] # normalize
M = Mesh.TreeMesh(h)
M._refineCell([0,0,0])
M._refineCell([0,0,1])
self.assertRaises(CellLookUpException, M._refineCell, [0,0,1])
assert M._index([0,0,1]) not in M
assert M._index([0,0,2]) in M
assert M._index([2,0,2]) in M
assert M._index([0,2,2]) in M
assert M._index([2,2,2]) in M
self.assertRaises(CellLookUpException, M._corsenCell, [0,0,1])
M._corsenCell([0,0,2])
assert M._index([0,0,1]) in M
assert M._index([0,0,2]) not in M
assert M._index([2,0,2]) not in M
assert M._index([0,2,2]) not in M
assert M._index([2,2,2]) not in M
M._refineCell([0,0,1])
self.assertRaises(CellLookUpException, M._corsenCell, [0,0,1])
M._corsenCell([2,0,2])
assert M._index([0,0,1]) in M
assert M._index([0,0,2]) not in M
assert M._index([2,0,2]) not in M
assert M._index([0,2,2]) not in M
assert M._index([2,2,2]) not in M
M._refineCell([0,0,1])
self.assertRaises(CellLookUpException, M._corsenCell, [0,0,1])
M._corsenCell([0,2,2])
assert M._index([0,0,1]) in M
assert M._index([0,0,2]) not in M
assert M._index([2,0,2]) not in M
assert M._index([0,2,2]) not in M
assert M._index([2,2,2]) not in M
M._refineCell([0,0,1])
self.assertRaises(CellLookUpException, M._corsenCell, [0,0,1])
M._corsenCell([2,2,2])
assert M._index([0,0,1]) in M
assert M._index([0,0,2]) not in M
assert M._index([2,0,2]) not in M
assert M._index([0,2,2]) not in M
assert M._index([2,2,2]) not in M
def test_faceDiv(self):