diff --git a/SimPEG/utils/Solver.py b/SimPEG/utils/Solver.py index 16db872f..5861c146 100644 --- a/SimPEG/utils/Solver.py +++ b/SimPEG/utils/Solver.py @@ -19,6 +19,11 @@ except Exception, e: DEFAULTS['forward'] = 'python' DEFAULTS['backward'] = 'python' +try: + import mumps +except Exception, e: + print 'Warning: mumps solver not available.' + class Solver(object): """ Solver is a light wrapper on the various types of @@ -113,6 +118,9 @@ class Solver(object): def clean(self): """Cleans up the memory""" + if self.options.has_key('backend'): + if self.options['backend'] == 'mumps': + self.mctx.destroy() del self.dsolve self.dsolve = None @@ -120,6 +128,7 @@ class Solver(object): """ Use solve instead of this interface. + :param numpy.ndarray b: the right hand side :param bool factorize: if you want to factorize and store factors :param str backend: which backend to use. Default is scipy :rtype: numpy.ndarray @@ -129,6 +138,22 @@ class Solver(object): assert np.shape(self.A)[1] == np.shape(b)[0], 'Dimension mismatch' + if backend == 'scipy': + X = self.solveDirect_scipy(b, factorize) + elif backend == 'mumps': + X = self.solveDirect_mumps(b, factorize) + + return X + + def solveDirect_scipy(self, b, factorize): + """ + Use solve instead of this interface. + + :param numpy.ndarray b: the right hand side + :param bool factorize: if you want to factorize and store factors + :rtype: numpy.ndarray + :return: x + """ if factorize and self.dsolve is None: self.A = self.A.tocsc() # for efficiency self.dsolve = linalg.factorized(self.A) @@ -150,6 +175,48 @@ class Solver(object): return X + def solveDirect_mumps(self, b, factorize): + """ + Use solve instead of this interface. + + :param numpy.ndarray b: the right hand side + :param bool factorize: if you want to factorize and store factors + :rtype: numpy.ndarray + :return: x + """ + if factorize and self.dsolve is None: + self.mctx = mumps.DMumpsContext() + self.mctx.set_icntl(14, 60) + # self.mctx.set_silent() + self.mctx.set_centralized_sparse(self.A) + self.mctx.run(job=4) + + def mdsolve(rhs): + x = rhs.copy() + self.mctx.set_rhs(x) + self.mctx.run(job=3) + return x + + self.dsolve = mdsolve + + if len(b.shape) == 1 or b.shape[1] == 1: + # Just one RHS + if factorize: + X = self.dsolve(b) + else: + X = mumps.spsolve(self.A, b) + + else: + # Multiple RHSs + X = np.empty_like(b) + for i in range(b.shape[1]): + if factorize: + X[:,i] = self.dsolve(b[:,i]) + else: + X[:,i] = mumps.spsolve(self.A,b[:,i]) + + return X + def solveIter(self, b, backend=None, M=None, iterSolver='CG', tol=1e-6, maxIter=50): if backend is None: backend = DEFAULTS['iter']