From dcb9b8787d2e22e6c67ecf76195914e97367c232 Mon Sep 17 00:00:00 2001 From: Rowan Cockett Date: Wed, 18 Nov 2015 17:15:26 -0800 Subject: [PATCH] Corsen Trees --- SimPEG/Mesh/TreeMesh.py | 114 +++++++++++++++++++++++------------- tests/mesh/test_TreeMesh.py | 62 +++++++++++++++++++- 2 files changed, 132 insertions(+), 44 deletions(-) diff --git a/SimPEG/Mesh/TreeMesh.py b/SimPEG/Mesh/TreeMesh.py index 28a892f3..c510a842 100644 --- a/SimPEG/Mesh/TreeMesh.py +++ b/SimPEG/Mesh/TreeMesh.py @@ -411,6 +411,10 @@ class TreeMesh(BaseTensorMesh, InnerProducts): def refine(self, function=None, recursive=True, cells=None, balance=True, verbose=False, _inRecursion=False): + if type(function) in [int, long]: + level = function + function = lambda cell: level + if not _inRecursion: self.__dirty__ = True if verbose: print 'Refining Mesh' @@ -421,65 +425,86 @@ class TreeMesh(BaseTensorMesh, InnerProducts): for cell in cells: p = self._pointer(cell) if p[-1] >= self.levels: continue - do = function(Cell(self, cell, p)) > p[-1] + result = function(Cell(self, cell, p)) + if type(result) is bool: + do = result + elif type(result) in [int,long]: + do = result > p[-1] + else: + raise Exception('You must tell the program what to refine. Use BOOL or INT (level)') if do: recurse += self._refineCell(cell, p) if verbose: print ' ', time.time() - tic if recursive and len(recurse) > 0: - recurse += self.refine(function=function, recursive=True, cells=recurse, balance=balance, _inRecursion=True) + recurse += self.refine(function=function, recursive=True, cells=recurse, balance=balance, verbose=verbose, _inRecursion=True) if balance and not _inRecursion: self.balance() return recurse + def corsen(self, function=None, recursive=True, cells=None, balance=True, verbose=False, _inRecursion=False): + + if type(function) in [int, long]: + level = function + function = lambda cell: level + + if not _inRecursion: + self.__dirty__ = True + if verbose: print 'Corsening Mesh' + + cells = cells if cells is not None else sorted(self._cells) + recurse = [] + tic = time.time() + for cell in cells: + if cell not in self._cells: continue # already removed + p = self._pointer(cell) + if p[-1] >= self.levels: continue + result = function(Cell(self, cell, p)) + if type(result) is bool: + do = result + elif type(result) in [int,long]: + do = result < p[-1] + else: + raise Exception('You must tell the program what to corsen. Use BOOL or INT (level)') + if do: + recurse += self._corsenCell(cell, p) + + if verbose: print ' ', time.time() - tic + + if recursive and len(recurse) > 0: + recurse += self.corsen(function=function, recursive=True, cells=recurse, balance=balance, verbose=verbose, _inRecursion=True) + + if balance and not _inRecursion: + self.balance() + return recurse + + if verbose: print ' ', time.time() - tic + def _refineCell(self, ind, pointer=None): ind = self._asIndex(ind) pointer = self._asPointer(pointer if pointer is not None else ind) - assert ind in self - h = self._levelWidth(pointer[-1])/2 # halfWidth - nL = pointer[-1] + 1 # new level - add = lambda p:p[0]+p[1] - added = [] - def addCell(p): - i = self._index(p+[nL]) - self._cells.add(i) - added.append(i) - - addCell(map(add, zip(pointer[:-1], [0,0,0][:self.dim]))) - addCell(map(add, zip(pointer[:-1], [h,0,0][:self.dim]))) - addCell(map(add, zip(pointer[:-1], [0,h,0][:self.dim]))) - addCell(map(add, zip(pointer[:-1], [h,h,0][:self.dim]))) - if self.dim == 3: - addCell(map(add, zip(pointer[:-1], [0,0,h]))) - addCell(map(add, zip(pointer[:-1], [h,0,h]))) - addCell(map(add, zip(pointer[:-1], [0,h,h]))) - addCell(map(add, zip(pointer[:-1], [h,h,h]))) + if ind not in self: + raise CellLookUpException(ind) + children = self._childPointers(pointer, returnAll=True) + for child in children: + self._cells.add(self._asIndex(child)) self._cells.remove(ind) - return added + return [self._asIndex(child) for child in children] - def corsen(self, function=None): - self.__dirty__ = True - raise Exception('Not yet implemented') - - - def _corsenCell(self, pointer): - raise Exception('Not yet implemented') - - # something like this: ?? - pointer = self._asPointer(pointer) - ind = self._asIndex(pointer) - assert ind in self - - parent = self._parentPointer(ind) - children = _childPointers(parent) + def _corsenCell(self, ind, pointer=None): + ind = self._asIndex(ind) + pointer = self._asPointer(pointer if pointer is not None else ind) + if ind not in self: + raise CellLookUpException(ind) + parent = self._parentPointer(pointer) + children = self._childPointers(parent, returnAll=True) for child in children: self._cells.remove(self._asIndex(child)) - parentInd = self._asIndex(parent) self._cells.add(parentInd) - return parentInd + return [parentInd] def _asPointer(self, ind): if type(ind) in [int, long]: @@ -2225,7 +2250,12 @@ def SortGrid(grid, offset=0): return sorted(range(offset,grid.shape[0]+offset), key=K) -class NotBalancedException(Exception): + +class TreeException(Exception): + pass +class NotBalancedException(TreeException): + pass +class CellLookUpException(TreeException): pass if __name__ == '__main__': @@ -2250,9 +2280,9 @@ if __name__ == '__main__': return 2 # T = TreeMesh([[(1,128)],[(1,128)],[(1,128)]],levels=7) - T = TreeMesh([128,128,128]) + # T = TreeMesh([128,128,128]) # T = TreeMesh([64,64],levels=6) - # T = TreeMesh([4,4,4],levels=2) + T = TreeMesh([4,4,4]) # T = TreeMesh([[(1,128)],[(1,128)]],levels=7) # T.refine(lambda xc:2, balance=False) # T._index([0,0,0]) diff --git a/tests/mesh/test_TreeMesh.py b/tests/mesh/test_TreeMesh.py index 61c2f9be..e624ce87 100644 --- a/tests/mesh/test_TreeMesh.py +++ b/tests/mesh/test_TreeMesh.py @@ -1,12 +1,11 @@ from SimPEG import Mesh, Tests +from SimPEG.Mesh.TreeMesh import CellLookUpException import numpy as np import matplotlib.pyplot as plt import unittest TOL = 1e-8 - - class TestSimpleQuadTree(unittest.TestCase): def test_counts(self): @@ -19,6 +18,7 @@ class TestSimpleQuadTree(unittest.TestCase): M._refineCell([0,0,1]) M.number() # M.plotGrid(showIt=True) + print M assert M.nhFx == 2 assert M.nFx == 9 @@ -26,6 +26,64 @@ class TestSimpleQuadTree(unittest.TestCase): assert np.allclose(np.r_[M._areaFxFull, M._areaFyFull], M._deflationMatrix('F') * M.area) + def test_refine(self): + M = Mesh.TreeMesh([4,4,4]) + M.refine(1) + assert M.nC == 8 + M.refine(0) + assert M.nC == 8 + M.corsen(0) + assert M.nC == 1 + + def test_corsen(self): + nc = 8 + h1 = np.random.rand(nc)*nc*0.5 + nc*0.5 + h2 = np.random.rand(nc)*nc*0.5 + nc*0.5 + h = [hi/np.sum(hi) for hi in [h1, h2]] # normalize + M = Mesh.TreeMesh(h) + M._refineCell([0,0,0]) + M._refineCell([0,0,1]) + self.assertRaises(CellLookUpException, M._refineCell, [0,0,1]) + assert M._index([0,0,1]) not in M + assert M._index([0,0,2]) in M + assert M._index([2,0,2]) in M + assert M._index([0,2,2]) in M + assert M._index([2,2,2]) in M + + self.assertRaises(CellLookUpException, M._corsenCell, [0,0,1]) + M._corsenCell([0,0,2]) + assert M._index([0,0,1]) in M + assert M._index([0,0,2]) not in M + assert M._index([2,0,2]) not in M + assert M._index([0,2,2]) not in M + assert M._index([2,2,2]) not in M + M._refineCell([0,0,1]) + + self.assertRaises(CellLookUpException, M._corsenCell, [0,0,1]) + M._corsenCell([2,0,2]) + assert M._index([0,0,1]) in M + assert M._index([0,0,2]) not in M + assert M._index([2,0,2]) not in M + assert M._index([0,2,2]) not in M + assert M._index([2,2,2]) not in M + M._refineCell([0,0,1]) + + self.assertRaises(CellLookUpException, M._corsenCell, [0,0,1]) + M._corsenCell([0,2,2]) + assert M._index([0,0,1]) in M + assert M._index([0,0,2]) not in M + assert M._index([2,0,2]) not in M + assert M._index([0,2,2]) not in M + assert M._index([2,2,2]) not in M + M._refineCell([0,0,1]) + + self.assertRaises(CellLookUpException, M._corsenCell, [0,0,1]) + M._corsenCell([2,2,2]) + assert M._index([0,0,1]) in M + assert M._index([0,0,2]) not in M + assert M._index([2,0,2]) not in M + assert M._index([0,2,2]) not in M + assert M._index([2,2,2]) not in M def test_faceDiv(self):