mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-29 16:16:18 +08:00
initial solver work: import SimPEG.Solver
This commit is contained in:
@@ -0,0 +1,75 @@
|
||||
import numpy as np
|
||||
import scipy.sparse.linalg as linalg
|
||||
|
||||
|
||||
class Solver(object):
|
||||
"""docstring for Solver"""
|
||||
def __init__(self, A, doDirect=True, flag=None, options={}):
|
||||
|
||||
assert type(doDirect) is bool, 'doDirect must be a boolean'
|
||||
assert flag in [None, 'L', 'U', 'D'], "flag must be set to None, 'L', 'U', or 'D'"
|
||||
|
||||
self.A = A
|
||||
|
||||
self.dsolve = None
|
||||
self.doDirect = doDirect
|
||||
self.flag = flag
|
||||
self.options = options
|
||||
|
||||
def solve(self, b):
|
||||
if self.flag is None and self.doDirect:
|
||||
return self.solveDirect(b, **self.options)
|
||||
elif self.flag is None and not self.doDirect:
|
||||
return self.solveIter(b, **self.options)
|
||||
elif self.flag == 'U':
|
||||
return self.solveBackward(b)
|
||||
elif self.flag == 'L':
|
||||
return self.solveForward(b)
|
||||
elif self.flag == 'D':
|
||||
return self.solveDiagonal(b)
|
||||
else:
|
||||
raise Exception('Unknown flag.')
|
||||
pass
|
||||
|
||||
def clean(self):
|
||||
"""Cleans up the memory"""
|
||||
del self.dsolve
|
||||
self.dsolve = None
|
||||
|
||||
def solveDirect(self, b, backend='scipy'):
|
||||
assert np.shape(self.A)[1] == np.shape(b)[0], 'Dimension mismatch'
|
||||
|
||||
if self.dsolve is None:
|
||||
self.A = self.A.tocsc() # for efficiency
|
||||
self.dsolve = linalg.factorized(self.A)
|
||||
|
||||
if len(b.shape) == 1 or b.shape[1] == 1:
|
||||
# Just one RHS
|
||||
return self.dsolve(b)
|
||||
|
||||
# Multiple RHSs
|
||||
X = np.empty_like(b)
|
||||
for i in range(b.shape[1]):
|
||||
X[:,i] = self.dsolve(b[:,i])
|
||||
|
||||
return X
|
||||
|
||||
def solveIter(self, b, M=None, iterSolver='CG'):
|
||||
pass
|
||||
|
||||
def solveBackward(self, b):
|
||||
pass
|
||||
|
||||
def solveForward(self, b):
|
||||
pass
|
||||
|
||||
def solveDiagonal(self, b):
|
||||
diagA = self.A.diagonal()
|
||||
if len(b.shape) == 1 or b.shape[1] == 1:
|
||||
# Just one RHS
|
||||
return b/diagA
|
||||
# Multiple RHSs
|
||||
X = np.empty_like(b)
|
||||
for i in range(b.shape[1]):
|
||||
X[:,i] = b[:,i]/diagA
|
||||
return X
|
||||
@@ -2,3 +2,4 @@ from TensorMesh import TensorMesh
|
||||
from LogicallyOrthogonalMesh import LogicallyOrthogonalMesh
|
||||
import utils
|
||||
import inverse
|
||||
from Solver import Solver
|
||||
|
||||
@@ -83,6 +83,16 @@ class Problem(object):
|
||||
def dobs(self, value):
|
||||
self._dobs = value
|
||||
|
||||
def evalFunction(self, m, doDerivative=True):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
:param bool doDerivative: do you want to compute the derivative?
|
||||
:rtype: numpy.array
|
||||
:return: Jv
|
||||
"""
|
||||
f = self.misfit(m)
|
||||
|
||||
return f, g, H
|
||||
|
||||
def J(self, m, v, u=None):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user