diff --git a/SimPEG/Solver.py b/SimPEG/Solver.py new file mode 100644 index 00000000..2087ef5f --- /dev/null +++ b/SimPEG/Solver.py @@ -0,0 +1,75 @@ +import numpy as np +import scipy.sparse.linalg as linalg + + +class Solver(object): + """docstring for Solver""" + def __init__(self, A, doDirect=True, flag=None, options={}): + + assert type(doDirect) is bool, 'doDirect must be a boolean' + assert flag in [None, 'L', 'U', 'D'], "flag must be set to None, 'L', 'U', or 'D'" + + self.A = A + + self.dsolve = None + self.doDirect = doDirect + self.flag = flag + self.options = options + + def solve(self, b): + if self.flag is None and self.doDirect: + return self.solveDirect(b, **self.options) + elif self.flag is None and not self.doDirect: + return self.solveIter(b, **self.options) + elif self.flag == 'U': + return self.solveBackward(b) + elif self.flag == 'L': + return self.solveForward(b) + elif self.flag == 'D': + return self.solveDiagonal(b) + else: + raise Exception('Unknown flag.') + pass + + def clean(self): + """Cleans up the memory""" + del self.dsolve + self.dsolve = None + + def solveDirect(self, b, backend='scipy'): + assert np.shape(self.A)[1] == np.shape(b)[0], 'Dimension mismatch' + + if self.dsolve is None: + self.A = self.A.tocsc() # for efficiency + self.dsolve = linalg.factorized(self.A) + + if len(b.shape) == 1 or b.shape[1] == 1: + # Just one RHS + return self.dsolve(b) + + # Multiple RHSs + X = np.empty_like(b) + for i in range(b.shape[1]): + X[:,i] = self.dsolve(b[:,i]) + + return X + + def solveIter(self, b, M=None, iterSolver='CG'): + pass + + def solveBackward(self, b): + pass + + def solveForward(self, b): + pass + + def solveDiagonal(self, b): + diagA = self.A.diagonal() + if len(b.shape) == 1 or b.shape[1] == 1: + # Just one RHS + return b/diagA + # Multiple RHSs + X = np.empty_like(b) + for i in range(b.shape[1]): + X[:,i] = b[:,i]/diagA + return X diff --git a/SimPEG/__init__.py b/SimPEG/__init__.py index f36000a7..ea9faf33 100644 --- a/SimPEG/__init__.py +++ b/SimPEG/__init__.py @@ -2,3 +2,4 @@ from TensorMesh import TensorMesh from LogicallyOrthogonalMesh import LogicallyOrthogonalMesh import utils import inverse +from Solver import Solver diff --git a/SimPEG/forward/Problem.py b/SimPEG/forward/Problem.py index 1558ecf0..5b716f1f 100644 --- a/SimPEG/forward/Problem.py +++ b/SimPEG/forward/Problem.py @@ -83,6 +83,16 @@ class Problem(object): def dobs(self, value): self._dobs = value + def evalFunction(self, m, doDerivative=True): + """ + :param numpy.array m: model + :param bool doDerivative: do you want to compute the derivative? + :rtype: numpy.array + :return: Jv + """ + f = self.misfit(m) + + return f, g, H def J(self, m, v, u=None): """