mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-29 01:59:48 +08:00
ecc7aaf9b0
bug fix in dependentProperty
78 lines
2.6 KiB
Python
78 lines
2.6 KiB
Python
import numpy as np
|
|
from matutils import mkvc
|
|
import warnings
|
|
|
|
def DSolverWrap(fun, factorize=True, checkAccuracy=True, accuracyTol=1e-6):
    """Wrap a direct sparse solver in a class exposing ``solve(b)``.

    Parameters
    ----------
    fun : callable
        If ``factorize`` is True, ``fun(A, **kwargs)`` must return an object
        with a ``solve(rhs)`` method (e.g. ``scipy.sparse.linalg.splu``).
        Otherwise ``fun(A, rhs, **kwargs)`` must return the solution directly
        (e.g. ``scipy.sparse.linalg.spsolve``).
    factorize : bool
        If True, ``fun`` is called once in ``__init__`` and the returned
        object is reused for every right-hand side.
    checkAccuracy : bool
        If True, after each solve compute the relative residual
        ``||A X - b|| / ||b||`` and print + warn (RuntimeWarning) when it
        exceeds ``accuracyTol``.
    accuracyTol : float
        Threshold for the accuracy warning.

    Returns
    -------
    type
        A new class (named after ``fun``) whose instances are built as
        ``Solver(A, **kwargs)``. ``A`` is stored as ``A.tocsc()`` and
        ``kwargs`` are forwarded to ``fun``.
    """

    def _solve_one(self, rhs):
        # Solve for a single 1-D right-hand side.
        if factorize:
            # kwargs were already consumed by the factorization in __init__;
            # the factor's solve() does not take them (e.g. SuperLU.solve
            # rejects splu's ``permc_spec``), so they are not forwarded.
            return self.solver.solve(rhs)
        return fun(self.A, rhs, **self.kwargs)

    def __init__(self, A, **kwargs):
        self.A = A.tocsc()
        self.kwargs = kwargs
        if factorize:
            self.solver = fun(self.A, **kwargs)

    def solve(self, b):
        """Solve ``A X = b`` for a vector or for each column of a matrix.

        A 1-D ``b`` or an ``(n, 1)`` column is flattened and the result is
        1-D; an ``(n, k)`` ``b`` is solved column by column and the result has
        shape ``(n, k)``.
        """
        if len(b.shape) == 1 or b.shape[1] == 1:
            # Just one RHS
            b = b.flatten()
            X = _solve_one(self, b)
        else:  # Multiple RHSs
            # Allocate an inexact output dtype: empty_like(b) would truncate
            # integer right-hand sides and silently drop the imaginary part
            # when A is complex but b is real.
            dtype = np.result_type(self.A.dtype, b.dtype)
            if not np.issubdtype(dtype, np.inexact):
                dtype = np.float64
            X = np.empty(b.shape, dtype=dtype)
            for i in range(b.shape[1]):
                X[:, i] = _solve_one(self, b[:, i])

        if checkAccuracy:
            # A.dot is a matrix product for scipy sparse matrices and sparse
            # arrays alike (``*`` is element-wise for sparse arrays).
            residual = np.linalg.norm(np.ravel(self.A.dot(X) - b))
            bnorm = np.linalg.norm(np.ravel(b))
            # Use the absolute residual when b == 0 to avoid 0/0 -> nan.
            nrm = residual / bnorm if bnorm > 0 else residual
            if nrm > accuracyTol:
                msg = '### SolverWarning ###: Accuracy on solve is above tolerance: %e > %e' % (nrm, accuracyTol)
                print(msg)
                warnings.warn(msg, RuntimeWarning)

        return X

    # Callables such as functools.partial have no __name__.
    name = getattr(fun, '__name__', type(fun).__name__)
    return type(name, (object,), {"__init__": __init__, "solve": solve})
|
|
|
|
|
|
|
|
def ISolverWrap(fun, checkAccuracy=True, accuracyTol=1e-5):
    """Wrap an iterative sparse solver in a class exposing ``solve(b)``.

    Parameters
    ----------
    fun : callable
        Called as ``fun(A, rhs, **kwargs)`` for each right-hand side. If it
        returns a 2-tuple it is treated as ``(x, info)`` (the convention of
        scipy's iterative solvers such as ``cg``); any other return value is
        taken to be the solution itself.
    checkAccuracy : bool
        If True, after each solve compute the relative residual
        ``||A X - b|| / ||b||`` and print + warn (RuntimeWarning) when it
        exceeds ``accuracyTol``.
    accuracyTol : float
        Threshold for the accuracy warning.

    Returns
    -------
    type
        A new class (named after ``fun``) whose instances are built as
        ``Solver(A, **kwargs)``. ``A`` is stored as ``A.tocsc()`` and
        ``kwargs`` are forwarded to every ``fun`` call. ``self.info`` is None
        until ``fun`` returns an ``(x, info)`` tuple; afterwards it holds the
        info of the last ``solve`` call (for multiple right-hand sides, the
        first non-zero info, or the last info if all are zero).
    """

    def _solve_one(self, rhs):
        # Run ``fun`` on one 1-D right-hand side; return (x, info-or-None).
        out = fun(self.A, rhs, **self.kwargs)
        if isinstance(out, tuple) and len(out) == 2:
            # We are dealing with scipy output with an info!
            return out[0], out[1]
        return out, None

    def __init__(self, A, **kwargs):
        self.A = A.tocsc()
        self.kwargs = kwargs
        # Defined up front so reading it never raises AttributeError.
        self.info = None

    def solve(self, b):
        """Solve ``A X = b`` for a vector or for each column of a matrix.

        A 1-D ``b`` or an ``(n, 1)`` column is flattened and the result is
        1-D; an ``(n, k)`` ``b`` is solved column by column and the result has
        shape ``(n, k)``.
        """
        if len(b.shape) == 1 or b.shape[1] == 1:
            # Just one RHS
            b = b.flatten()
            X, info = _solve_one(self, b)
            if info is not None:
                self.info = info
        else:  # Multiple RHSs
            # Allocate an inexact output dtype: empty_like(b) would truncate
            # integer right-hand sides and silently drop the imaginary part
            # when A is complex but b is real.
            dtype = np.result_type(self.A.dtype, b.dtype)
            if not np.issubdtype(dtype, np.inexact):
                dtype = np.float64
            X = np.empty(b.shape, dtype=dtype)
            infos = []
            for i in range(b.shape[1]):
                x, info = _solve_one(self, b[:, i])
                X[:, i] = x
                if info is not None:
                    infos.append(info)
            if infos:
                # Keep the first failure so a later successful column
                # cannot mask an earlier non-converged one.
                failed = [code for code in infos if code != 0]
                self.info = failed[0] if failed else infos[-1]

        if checkAccuracy:
            # A.dot is a matrix product for scipy sparse matrices and sparse
            # arrays alike (``*`` is element-wise for sparse arrays).
            residual = np.linalg.norm(np.ravel(self.A.dot(X) - b))
            bnorm = np.linalg.norm(np.ravel(b))
            # Use the absolute residual when b == 0 to avoid 0/0 -> nan.
            nrm = residual / bnorm if bnorm > 0 else residual
            if nrm > accuracyTol:
                msg = '### SolverWarning ###: Accuracy on solve is above tolerance: %e > %e' % (nrm, accuracyTol)
                print(msg)
                warnings.warn(msg, RuntimeWarning)

        return X

    # Callables such as functools.partial have no __name__.
    name = getattr(fun, '__name__', type(fun).__name__)
    return type(name, (object,), {"__init__": __init__, "solve": solve})
|