diff --git a/SimPEG/Tests/test_Solver.py b/SimPEG/Tests/test_Solver.py index 9135e851..47797dcb 100644 --- a/SimPEG/Tests/test_Solver.py +++ b/SimPEG/Tests/test_Solver.py @@ -31,6 +31,7 @@ def dotest(solver, multi=False, **solverOpts): e = np.ones((M.nC, numRHS)) rhs = A * e x = Ainv * rhs + Ainv.clean() return np.linalg.norm(e-x,np.inf) < TOL class TestSolver(unittest.TestCase): @@ -41,8 +42,8 @@ class TestSolver(unittest.TestCase): def test_direct_splu_1(self): self.assertTrue(dotest(SolverLU, False)) def test_direct_splu_M(self): self.assertTrue(dotest(SolverLU, True)) - def test_iterative_cg_1(self): self.assertTrue(dotest(SolverLU, False)) - def test_iterative_cg_M(self): self.assertTrue(dotest(SolverLU, True)) + def test_iterative_cg_1(self): self.assertTrue(dotest(SolverCG, False)) + def test_iterative_cg_M(self): self.assertTrue(dotest(SolverCG, True)) if __name__ == '__main__': diff --git a/SimPEG/Utils/SolverUtils.py b/SimPEG/Utils/SolverUtils.py index af721d82..c2f17b6d 100644 --- a/SimPEG/Utils/SolverUtils.py +++ b/SimPEG/Utils/SolverUtils.py @@ -42,10 +42,9 @@ def DSolverWrap(fun, factorize=True, checkAccuracy=True, accuracyTol=1e-6): return X def clean(self): - if hasattr(self.solver, 'clean'): + if factorize and hasattr(self.solver, 'clean'): return self.solver.clean() - def __mul__(self, val): if type(val) is np.ndarray: return self.solve(val) @@ -87,9 +86,13 @@ def ISolverWrap(fun, checkAccuracy=True, accuracyTol=1e-5): _checkAccuracy(self.A, b, X, accuracyTol) return X + def clean(self): + if hasattr(self.solver, 'clean'): + return self.solver.clean() + def __mul__(self, val): if type(val) is np.ndarray: return self.solve(val) raise TypeError('Can only multiply by a numpy array.') - return type(fun.__name__, (object,), {"__init__": __init__, "solve": solve, "__mul__": __mul__}) + return type(fun.__name__, (object,), {"__init__": __init__, "solve": solve, "clean": clean, "__mul__": __mul__})