update solver utils.

This commit is contained in:
rowanc1
2014-05-16 18:46:31 -07:00
parent f3170a8a7f
commit 26340fe90f
2 changed files with 9 additions and 5 deletions
+3 -2
View File
@@ -31,6 +31,7 @@ def dotest(solver, multi=False, **solverOpts):
e = np.ones((M.nC, numRHS))
rhs = A * e
x = Ainv * rhs
Ainv.clean()
return np.linalg.norm(e-x,np.inf) < TOL
class TestSolver(unittest.TestCase):
@@ -41,8 +42,8 @@ class TestSolver(unittest.TestCase):
def test_direct_splu_1(self): self.assertTrue(dotest(SolverLU, False))
def test_direct_splu_M(self): self.assertTrue(dotest(SolverLU, True))
def test_iterative_cg_1(self): self.assertTrue(dotest(SolverLU, False))
def test_iterative_cg_M(self): self.assertTrue(dotest(SolverLU, True))
def test_iterative_cg_1(self): self.assertTrue(dotest(SolverCG, False))
def test_iterative_cg_M(self): self.assertTrue(dotest(SolverCG, True))
if __name__ == '__main__':
+6 -3
View File
@@ -42,10 +42,9 @@ def DSolverWrap(fun, factorize=True, checkAccuracy=True, accuracyTol=1e-6):
return X
def clean(self):
if hasattr(self.solver, 'clean'):
if factorize and hasattr(self.solver, 'clean'):
return self.solver.clean()
def __mul__(self, val):
if type(val) is np.ndarray:
return self.solve(val)
@@ -87,9 +86,13 @@ def ISolverWrap(fun, checkAccuracy=True, accuracyTol=1e-5):
_checkAccuracy(self.A, b, X, accuracyTol)
return X
def clean(self):
if hasattr(self.solver, 'clean'):
return self.solver.clean()
def __mul__(self, val):
if type(val) is np.ndarray:
return self.solve(val)
raise TypeError('Can only multiply by a numpy array.')
return type(fun.__name__, (object,), {"__init__": __init__, "solve": solve, "__mul__": __mul__})
return type(fun.__name__, (object,), {"__init__": __init__, "solve": solve, "clean": clean, "__mul__": __mul__})