mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-01 11:44:35 +08:00
Test solver
This commit is contained in:
+28
-21
@@ -5,26 +5,29 @@ from SimPEG.Utils import sdiag
|
||||
import numpy as np
|
||||
import scipy.sparse as sparse
|
||||
|
||||
TOL = 1e-10
|
||||
TOLD = 1e-10
|
||||
TOLI = 1e-3
|
||||
numRHS = 5
|
||||
|
||||
def dotest(solver, multi=False, **solverOpts):
|
||||
h1 = np.ones(10)*100.
|
||||
h2 = np.ones(10)*100.
|
||||
h3 = np.ones(10)*100.
|
||||
def dotest(MYSOLVER, multi=False, A=None, **solverOpts):
|
||||
if A is None:
|
||||
h1 = np.ones(10)*100.
|
||||
h2 = np.ones(10)*100.
|
||||
h3 = np.ones(10)*100.
|
||||
|
||||
h = [h1,h2,h3]
|
||||
h = [h1,h2,h3]
|
||||
|
||||
M = TensorMesh(h)
|
||||
M = TensorMesh(h)
|
||||
|
||||
D = M.faceDiv
|
||||
G = -M.faceDiv.T
|
||||
Msig = M.getFaceInnerProduct()
|
||||
A = D*Msig*G
|
||||
A[0,0] *= 10 # remove the constant null space from the matrix
|
||||
D = M.faceDiv
|
||||
G = -M.faceDiv.T
|
||||
Msig = M.getFaceInnerProduct()
|
||||
A = D*Msig*G
|
||||
A[-1,-1] *= 1/M.vol[-1] # remove the constant null space from the matrix
|
||||
else:
|
||||
M = Mesh.TensorMesh([A.shape[0]])
|
||||
|
||||
|
||||
Ainv = Solver(A, **solverOpts)
|
||||
Ainv = MYSOLVER(A, **solverOpts)
|
||||
if multi:
|
||||
e = np.ones(M.nC)
|
||||
else:
|
||||
@@ -32,18 +35,22 @@ def dotest(solver, multi=False, **solverOpts):
|
||||
rhs = A * e
|
||||
x = Ainv * rhs
|
||||
Ainv.clean()
|
||||
return np.linalg.norm(e-x,np.inf) < TOL
|
||||
return np.linalg.norm(e-x,np.inf)
|
||||
|
||||
class TestSolver(unittest.TestCase):
|
||||
|
||||
def test_direct_spsolve_1(self): self.assertTrue(dotest(Solver, False))
|
||||
def test_direct_spsolve_M(self): self.assertTrue(dotest(Solver, True))
|
||||
def test_direct_spsolve_1(self): self.assertLess(dotest(Solver, False),TOLD)
|
||||
def test_direct_spsolve_M(self): self.assertLess(dotest(Solver, True),TOLD)
|
||||
|
||||
def test_direct_splu_1(self): self.assertTrue(dotest(SolverLU, False))
|
||||
def test_direct_splu_M(self): self.assertTrue(dotest(SolverLU, True))
|
||||
def test_direct_splu_1(self): self.assertLess(dotest(SolverLU, False),TOLD)
|
||||
def test_direct_splu_M(self): self.assertLess(dotest(SolverLU, True),TOLD)
|
||||
|
||||
def test_iterative_diag_1(self): self.assertLess(dotest(SolverDiag, False, A=Utils.sdiag(np.random.rand(10)+1.0)),TOLI)
|
||||
def test_iterative_diag_M(self): self.assertLess(dotest(SolverDiag, True, A=Utils.sdiag(np.random.rand(10)+1.0)),TOLI)
|
||||
|
||||
def test_iterative_cg_1(self): self.assertLess(dotest(SolverCG, False),TOLI)
|
||||
def test_iterative_cg_M(self): self.assertLess(dotest(SolverCG, True),TOLI)
|
||||
|
||||
def test_iterative_cg_1(self): self.assertTrue(dotest(SolverCG, False))
|
||||
def test_iterative_cg_M(self): self.assertTrue(dotest(SolverCG, True))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -66,18 +66,6 @@ class MapTests(unittest.TestCase):
|
||||
self.assertLess(np.linalg.norm(mod.transform - np.r_[1,1,2,2,10,10,10,10.]), TOL)
|
||||
self.assertTrue(mod.test())
|
||||
|
||||
def test_activeCells(self):
|
||||
M = Mesh.TensorMesh([2,4],'0C')
|
||||
actMap = Maps.ActiveCells(M, M.vectorCCy <=0, 10, nC=M.nCy)
|
||||
vertMap = Maps.Vertical1DMap(M)
|
||||
mod = Maps.Model(np.r_[1,2.],vertMap * actMap)
|
||||
# import matplotlib.pyplot as plt
|
||||
# plt.colorbar(M.plotImage(mod.transform)[0])
|
||||
# plt.show()
|
||||
self.assertLess(np.linalg.norm(mod.transform - np.r_[1,1,2,2,10,10,10,10.]), TOL)
|
||||
self.assertTrue(mod.test())
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -105,10 +105,9 @@ def SolverWrapI(fun, checkAccuracy=True, accuracyTol=1e-5):
|
||||
return X
|
||||
|
||||
def clean(self):
|
||||
if hasattr(self.solver, 'clean'):
|
||||
return self.solver.clean()
|
||||
pass
|
||||
|
||||
return type(fun.__name__, (object,), {"__init__": __init__, "clean": clean, "__mul__": __mul__})
|
||||
return type(fun.__name__+'_Wrapped', (object,), {"__init__": __init__, "clean": clean, "__mul__": __mul__})
|
||||
|
||||
|
||||
Solver = SolverWrapD(sp.linalg.spsolve, factorize=False)
|
||||
|
||||
Reference in New Issue
Block a user