mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-05 00:51:58 +08:00
MeatClasses to make every SimPEG object saveable to an hdf5 file.
cleaned up imports in a lot of places. Made solver not copy matrices around for GS preconditioning.
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
import utils
|
||||
from utils import Solver
|
||||
import mesh
|
||||
|
||||
@@ -9,6 +9,9 @@ import matplotlib.pyplot as plt
|
||||
class LinearProblem(Problem):
|
||||
"""docstring for LinearProblem"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
Problem.__init__(self, *args, **kwargs)
|
||||
|
||||
def dpred(self, m, u=None):
|
||||
return self.G.dot(m)
|
||||
|
||||
|
||||
+10
-10
@@ -1,6 +1,4 @@
|
||||
import numpy as np
|
||||
from SimPEG.utils import mkvc, sdiag, count, timeIt
|
||||
import scipy.sparse as sp
|
||||
from SimPEG import utils, np, sp
|
||||
norm = np.linalg.norm
|
||||
|
||||
|
||||
@@ -37,6 +35,8 @@ class Problem(object):
|
||||
to (locally) find how model parameters change the data, and optimize!
|
||||
"""
|
||||
|
||||
__metaclass__ = utils.Save.Savable
|
||||
|
||||
counter = None
|
||||
|
||||
def __init__(self, mesh):
|
||||
@@ -85,7 +85,7 @@ class Problem(object):
|
||||
def dobs(self, value):
|
||||
self._dobs = value
|
||||
|
||||
@count
|
||||
@utils.count
|
||||
def dpred(self, m, u=None):
|
||||
"""
|
||||
Predicted data.
|
||||
@@ -97,7 +97,7 @@ class Problem(object):
|
||||
u = self.field(m)
|
||||
return self.P*u
|
||||
|
||||
@count
|
||||
@utils.count
|
||||
def dataResidual(self, m, u=None):
|
||||
"""
|
||||
:param numpy.array m: geophysical model
|
||||
@@ -117,7 +117,7 @@ class Problem(object):
|
||||
|
||||
return self.dpred(m, u=u) - self.dobs
|
||||
|
||||
@timeIt
|
||||
@utils.timeIt
|
||||
def J(self, m, v, u=None):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
@@ -147,7 +147,7 @@ class Problem(object):
|
||||
"""
|
||||
raise NotImplementedError('J is not yet implemented.')
|
||||
|
||||
@timeIt
|
||||
@utils.timeIt
|
||||
def Jt(self, m, v, u=None):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
@@ -161,7 +161,7 @@ class Problem(object):
|
||||
raise NotImplementedError('Jt is not yet implemented.')
|
||||
|
||||
|
||||
@timeIt
|
||||
@utils.timeIt
|
||||
def J_approx(self, m, v, u=None):
|
||||
"""
|
||||
|
||||
@@ -176,7 +176,7 @@ class Problem(object):
|
||||
"""
|
||||
return self.J(m, v, u)
|
||||
|
||||
@timeIt
|
||||
@utils.timeIt
|
||||
def Jt_approx(self, m, v, u=None):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
@@ -241,7 +241,7 @@ class Problem(object):
|
||||
dobs = self.dpred(m,u=u)
|
||||
noise = std*abs(dobs)*np.random.randn(*dobs.shape)
|
||||
dobs = dobs+noise
|
||||
eps = np.linalg.norm(mkvc(dobs),2)*1e-5
|
||||
eps = np.linalg.norm(utils.mkvc(dobs),2)*1e-5
|
||||
Wd = 1/(abs(dobs)*std+eps)
|
||||
return dobs, Wd
|
||||
|
||||
|
||||
+19
-19
@@ -1,7 +1,5 @@
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
import SimPEG
|
||||
from SimPEG.utils import sdiag, mkvc, setKwargs, checkStoppers, printStoppers, count, timeIt, callHooks
|
||||
from SimPEG import utils, sp, np
|
||||
from Optimize import Remember
|
||||
from BetaSchedule import Cooling
|
||||
from SimPEG.inverse import IterationPrinters, StoppingCriteria
|
||||
@@ -9,6 +7,8 @@ from SimPEG.inverse import IterationPrinters, StoppingCriteria
|
||||
class BaseInversion(object):
|
||||
"""docstring for BaseInversion"""
|
||||
|
||||
__metaclass__ = utils.Save.Savable
|
||||
|
||||
maxIter = 1 #: Maximum number of iterations
|
||||
name = 'BaseInversion'
|
||||
|
||||
@@ -17,11 +17,11 @@ class BaseInversion(object):
|
||||
comment = '' #: Used by some functions to indicate what is going on in the algorithm
|
||||
counter = None #: Set this to a SimPEG.utils.Counter() if you want to count things
|
||||
|
||||
beta0 = None #: The initial Beta (regularization parameter)
|
||||
|
||||
beta0 = None #: The initial Beta (regularization parameter)
|
||||
beta0_ratio = 0.1 #: When beta0 is set to None, estimateBeta0 is used with this ratio
|
||||
|
||||
def __init__(self, prob, reg, opt, **kwargs):
|
||||
setKwargs(self, **kwargs)
|
||||
utils.setKwargs(self, **kwargs)
|
||||
self.prob = prob
|
||||
self.reg = reg
|
||||
self.opt = opt
|
||||
@@ -46,7 +46,7 @@ class BaseInversion(object):
|
||||
Standard deviation weighting matrix.
|
||||
"""
|
||||
if getattr(self,'_Wd',None) is None:
|
||||
eps = np.linalg.norm(mkvc(self.prob.dobs),2)*1e-5
|
||||
eps = np.linalg.norm(utils.mkvc(self.prob.dobs),2)*1e-5
|
||||
self._Wd = 1/(abs(self.prob.dobs)*self.prob.std+eps)
|
||||
return self._Wd
|
||||
@Wd.setter
|
||||
@@ -70,7 +70,7 @@ class BaseInversion(object):
|
||||
def phi_d_target(self, value):
|
||||
self._phi_d_target = value
|
||||
|
||||
@timeIt
|
||||
@utils.timeIt
|
||||
def run(self, m0):
|
||||
"""run(m0)
|
||||
|
||||
@@ -89,7 +89,7 @@ class BaseInversion(object):
|
||||
|
||||
return self.m
|
||||
|
||||
@callHooks('startup')
|
||||
@utils.callHooks('startup')
|
||||
def startup(self, m0):
|
||||
"""
|
||||
**startup** is called at the start of any new run call.
|
||||
@@ -109,7 +109,7 @@ class BaseInversion(object):
|
||||
self.phi_d_last = np.nan
|
||||
self.phi_m_last = np.nan
|
||||
|
||||
@callHooks('doStartIteration')
|
||||
@utils.callHooks('doStartIteration')
|
||||
def doStartIteration(self):
|
||||
"""
|
||||
**doStartIteration** is called at the end of each run iteration.
|
||||
@@ -120,7 +120,7 @@ class BaseInversion(object):
|
||||
self._beta = self.getBeta()
|
||||
|
||||
|
||||
@callHooks('doEndIteration')
|
||||
@utils.callHooks('doEndIteration')
|
||||
def doEndIteration(self):
|
||||
"""
|
||||
**doEndIteration** is called at the end of each run iteration.
|
||||
@@ -179,7 +179,7 @@ class BaseInversion(object):
|
||||
|
||||
def stoppingCriteria(self):
|
||||
if self.debug: print 'checking stoppingCriteria'
|
||||
return checkStoppers(self, self.stoppers)
|
||||
return utils.checkStoppers(self, self.stoppers)
|
||||
|
||||
|
||||
def printDone(self):
|
||||
@@ -189,7 +189,7 @@ class BaseInversion(object):
|
||||
"""
|
||||
printStoppers(self, self.stoppers)
|
||||
|
||||
@callHooks('finish')
|
||||
@utils.callHooks('finish')
|
||||
def finish(self):
|
||||
"""finish()
|
||||
|
||||
@@ -197,7 +197,7 @@ class BaseInversion(object):
|
||||
"""
|
||||
pass
|
||||
|
||||
@timeIt
|
||||
@utils.timeIt
|
||||
def evalFunction(self, m, return_g=True, return_H=True):
|
||||
"""evalFunction(m, return_g=True, return_H=True)
|
||||
|
||||
@@ -207,7 +207,7 @@ class BaseInversion(object):
|
||||
u = self.prob.field(m)
|
||||
|
||||
if self._iter is 0 and self._beta is None:
|
||||
self._beta = self.beta0 = self.estimateBeta0(u=u)
|
||||
self._beta = self.beta0 = self.estimateBeta0(u=u,ratio=self.beta0_ratio)
|
||||
|
||||
phi_d = self.dataObj(m, u)
|
||||
phi_m = self.reg.modelObj(m)
|
||||
@@ -237,7 +237,7 @@ class BaseInversion(object):
|
||||
out += (operator,)
|
||||
return out if len(out) > 1 else out[0]
|
||||
|
||||
@timeIt
|
||||
@utils.timeIt
|
||||
def dataObj(self, m, u=None):
|
||||
"""dataObj(m, u=None)
|
||||
|
||||
@@ -257,10 +257,10 @@ class BaseInversion(object):
|
||||
"""
|
||||
# TODO: ensure that this is a data is vector and Wd is a matrix.
|
||||
R = self.Wd*self.prob.dataResidual(m, u=u)
|
||||
R = mkvc(R)
|
||||
R = utils.mkvc(R)
|
||||
return 0.5*np.vdot(R, R)
|
||||
|
||||
@timeIt
|
||||
@utils.timeIt
|
||||
def dataObjDeriv(self, m, u=None):
|
||||
"""dataObjDeriv(m, u=None)
|
||||
|
||||
@@ -302,7 +302,7 @@ class BaseInversion(object):
|
||||
|
||||
return dmisfit
|
||||
|
||||
@timeIt
|
||||
@utils.timeIt
|
||||
def dataObj2Deriv(self, m, v, u=None):
|
||||
"""dataObj2Deriv(m, v, u=None)
|
||||
|
||||
|
||||
+44
-34
@@ -1,9 +1,6 @@
|
||||
import numpy as np
|
||||
from SimPEG import Solver, utils, sp, np
|
||||
import matplotlib.pyplot as plt
|
||||
from SimPEG.utils import mkvc, sdiag, setKwargs, printTitles, printLine, printStoppers, checkStoppers, count, timeIt, callHooks
|
||||
norm = np.linalg.norm
|
||||
import scipy.sparse as sp
|
||||
from SimPEG import Solver
|
||||
|
||||
|
||||
__all__ = ['Minimize', 'Remember', 'SteepestDescent', 'BFGS', 'GaussNewton', 'InexactGaussNewton', 'ProjectedGradient', 'NewtonRoot', 'StoppingCriteria', 'IterationPrinters']
|
||||
@@ -85,6 +82,8 @@ class Minimize(object):
|
||||
Minimize is a general class for derivative based optimization.
|
||||
"""
|
||||
|
||||
__metaclass__ = utils.Save.Savable
|
||||
|
||||
name = "General Optimization Algorithm" #: The name of the optimization algorithm
|
||||
|
||||
maxIter = 20 #: Maximum number of iterations
|
||||
@@ -110,9 +109,9 @@ class Minimize(object):
|
||||
self.printers = [IterationPrinters.iteration, IterationPrinters.f, IterationPrinters.norm_g, IterationPrinters.totalLS]
|
||||
self.printersLS = [IterationPrinters.iterationLS, IterationPrinters.LS_ft, IterationPrinters.LS_t, IterationPrinters.LS_armijoGoldstein]
|
||||
|
||||
setKwargs(self, **kwargs)
|
||||
utils.setKwargs(self, **kwargs)
|
||||
|
||||
@timeIt
|
||||
@utils.timeIt
|
||||
def minimize(self, evalFunction, x0):
|
||||
"""minimize(evalFunction, x0)
|
||||
|
||||
@@ -190,7 +189,7 @@ class Minimize(object):
|
||||
def parent(self, value):
|
||||
self._parent = value
|
||||
|
||||
@callHooks('startup')
|
||||
@utils.callHooks('startup')
|
||||
def startup(self, x0):
|
||||
"""
|
||||
**startup** is called at the start of any new minimize call.
|
||||
@@ -215,8 +214,8 @@ class Minimize(object):
|
||||
self.f_last = np.nan
|
||||
self.x_last = x0
|
||||
|
||||
@count
|
||||
@callHooks('doStartIteration')
|
||||
@utils.count
|
||||
@utils.callHooks('doStartIteration')
|
||||
def doStartIteration(self):
|
||||
"""doStartIteration()
|
||||
|
||||
@@ -238,9 +237,9 @@ class Minimize(object):
|
||||
"""
|
||||
pad = ' '*10 if inLS else ''
|
||||
name = self.name if not inLS else self.nameLS
|
||||
printTitles(self, self.printers if not inLS else self.printersLS, name, pad)
|
||||
utils.printTitles(self, self.printers if not inLS else self.printersLS, name, pad)
|
||||
|
||||
@callHooks('printIter')
|
||||
@utils.callHooks('printIter')
|
||||
def printIter(self, inLS=False):
|
||||
"""
|
||||
**printIter** is called directly after function evaluations.
|
||||
@@ -250,7 +249,7 @@ class Minimize(object):
|
||||
|
||||
"""
|
||||
pad = ' '*10 if inLS else ''
|
||||
printLine(self, self.printers if not inLS else self.printersLS, pad=pad)
|
||||
utils.printLine(self, self.printers if not inLS else self.printersLS, pad=pad)
|
||||
|
||||
def printDone(self, inLS=False):
|
||||
"""
|
||||
@@ -263,10 +262,10 @@ class Minimize(object):
|
||||
pad = ' '*10 if inLS else ''
|
||||
stop, done = (' STOP! ', ' DONE! ') if not inLS else ('----------------', ' End Linesearch ')
|
||||
stoppers = self.stoppers if not inLS else self.stoppersLS
|
||||
printStoppers(self, stoppers, pad='', stop=stop, done=done)
|
||||
utils.printStoppers(self, stoppers, pad='', stop=stop, done=done)
|
||||
|
||||
|
||||
@callHooks('finish')
|
||||
@utils.callHooks('finish')
|
||||
def finish(self):
|
||||
"""finish()
|
||||
|
||||
@@ -282,10 +281,10 @@ class Minimize(object):
|
||||
if self._iter == 0:
|
||||
self.f0 = self.f
|
||||
self.g0 = self.g
|
||||
return checkStoppers(self, self.stoppers if not inLS else self.stoppersLS)
|
||||
return utils.checkStoppers(self, self.stoppers if not inLS else self.stoppersLS)
|
||||
|
||||
@timeIt
|
||||
@callHooks('projection')
|
||||
@utils.timeIt
|
||||
@utils.callHooks('projection')
|
||||
def projection(self, p):
|
||||
"""projection(p)
|
||||
|
||||
@@ -299,7 +298,7 @@ class Minimize(object):
|
||||
"""
|
||||
return p
|
||||
|
||||
@timeIt
|
||||
@utils.timeIt
|
||||
def findSearchDirection(self):
|
||||
"""findSearchDirection()
|
||||
|
||||
@@ -330,7 +329,7 @@ class Minimize(object):
|
||||
"""
|
||||
return -self.g
|
||||
|
||||
@count
|
||||
@utils.count
|
||||
def scaleSearchDirection(self, p):
|
||||
"""scaleSearchDirection(p)
|
||||
|
||||
@@ -349,7 +348,7 @@ class Minimize(object):
|
||||
|
||||
nameLS = "Armijo linesearch" #: The line-search name
|
||||
|
||||
@timeIt
|
||||
@utils.timeIt
|
||||
def modifySearchDirection(self, p):
|
||||
"""modifySearchDirection(p)
|
||||
|
||||
@@ -387,7 +386,7 @@ class Minimize(object):
|
||||
|
||||
return self._LS_xt, self._iterLS < self.maxIterLS
|
||||
|
||||
@count
|
||||
@utils.count
|
||||
def modifySearchDirectionBreak(self, p):
|
||||
"""modifySearchDirectionBreak(p)
|
||||
|
||||
@@ -409,8 +408,8 @@ class Minimize(object):
|
||||
print 'The linesearch got broken. Boo.'
|
||||
return p, False
|
||||
|
||||
@count
|
||||
@callHooks('doEndIteration')
|
||||
@utils.count
|
||||
@utils.callHooks('doEndIteration')
|
||||
def doEndIteration(self, xt):
|
||||
"""doEndIteration(xt)
|
||||
|
||||
@@ -437,6 +436,8 @@ class Minimize(object):
|
||||
if getattr(self,'parent',None) is None:
|
||||
group.setArray('x', self.xc)
|
||||
else: # Assume inversion is the parent
|
||||
group.attrs['phi_d'] = self.parent.phi_d
|
||||
group.attrs['phi_m'] = self.parent.phi_m
|
||||
group.setArray('m', self.xc)
|
||||
group.setArray('dpred', self.parent.dpred)
|
||||
|
||||
@@ -526,7 +527,7 @@ class ProjectedGradient(Minimize, Remember):
|
||||
|
||||
self.aSet_prev = self.activeSet(x0)
|
||||
|
||||
@count
|
||||
@utils.count
|
||||
def projection(self, x):
|
||||
"""projection(x)
|
||||
|
||||
@@ -535,7 +536,7 @@ class ProjectedGradient(Minimize, Remember):
|
||||
"""
|
||||
return np.median(np.c_[self.lower,x,self.upper],axis=1)
|
||||
|
||||
@count
|
||||
@utils.count
|
||||
def activeSet(self, x):
|
||||
"""activeSet(x)
|
||||
|
||||
@@ -544,7 +545,7 @@ class ProjectedGradient(Minimize, Remember):
|
||||
"""
|
||||
return np.logical_or(x == self.lower, x == self.upper)
|
||||
|
||||
@count
|
||||
@utils.count
|
||||
def inactiveSet(self, x):
|
||||
"""inactiveSet(x)
|
||||
|
||||
@@ -553,7 +554,7 @@ class ProjectedGradient(Minimize, Remember):
|
||||
"""
|
||||
return np.logical_not(self.activeSet(x))
|
||||
|
||||
@count
|
||||
@utils.count
|
||||
def bindingSet(self, x):
|
||||
"""bindingSet(x)
|
||||
|
||||
@@ -566,7 +567,7 @@ class ProjectedGradient(Minimize, Remember):
|
||||
bind_low = np.logical_and(x == self.upper, self.g <= 0)
|
||||
return np.logical_or(bind_up, bind_low)
|
||||
|
||||
@timeIt
|
||||
@utils.timeIt
|
||||
def findSearchDirection(self):
|
||||
"""findSearchDirection()
|
||||
|
||||
@@ -611,7 +612,7 @@ class ProjectedGradient(Minimize, Remember):
|
||||
# aSet_after = self.activeSet(self.xc+p)
|
||||
return p
|
||||
|
||||
@timeIt
|
||||
@utils.timeIt
|
||||
def _doEndIteration_ProjectedGradient(self, xt):
|
||||
"""_doEndIteration_ProjectedGradient(xt)"""
|
||||
aSet = self.activeSet(xt)
|
||||
@@ -646,6 +647,9 @@ class BFGS(Minimize, Remember):
|
||||
name = 'BFGS'
|
||||
nbfgs = 10
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
Minimize.__init__(self, **kwargs)
|
||||
|
||||
@property
|
||||
def bfgsH0(self):
|
||||
"""
|
||||
@@ -711,7 +715,10 @@ class BFGS(Minimize, Remember):
|
||||
class GaussNewton(Minimize, Remember):
|
||||
name = 'Gauss Newton'
|
||||
|
||||
@timeIt
|
||||
def __init__(self, **kwargs):
|
||||
Minimize.__init__(self, **kwargs)
|
||||
|
||||
@utils.timeIt
|
||||
def findSearchDirection(self):
|
||||
return Solver(self.H).solve(-self.g)
|
||||
|
||||
@@ -758,7 +765,7 @@ class InexactGaussNewton(BFGS, Minimize, Remember):
|
||||
def approxHinv(self, value):
|
||||
self._approxHinv = value
|
||||
|
||||
@timeIt
|
||||
@utils.timeIt
|
||||
def findSearchDirection(self):
|
||||
Hinv = Solver(self.H, doDirect=False, options={'iterSolver': 'CG', 'M': self.approxHinv, 'tol': self.tolCG, 'maxIter': self.maxIterCG})
|
||||
p = Hinv.solve(-self.g)
|
||||
@@ -768,7 +775,10 @@ class InexactGaussNewton(BFGS, Minimize, Remember):
|
||||
class SteepestDescent(Minimize, Remember):
|
||||
name = 'Steepest Descent'
|
||||
|
||||
@timeIt
|
||||
def __init__(self, **kwargs):
|
||||
Minimize.__init__(self, **kwargs)
|
||||
|
||||
@utils.timeIt
|
||||
def findSearchDirection(self):
|
||||
return -self.g
|
||||
|
||||
@@ -801,7 +811,7 @@ class NewtonRoot(object):
|
||||
doLS = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
setKwargs(self, **kwargs)
|
||||
utils.setKwargs(self, **kwargs)
|
||||
|
||||
def root(self, fun, x):
|
||||
"""root(fun, x)
|
||||
@@ -875,7 +885,7 @@ if __name__ == '__main__':
|
||||
|
||||
|
||||
print 'test the newtonRoot finding.'
|
||||
fun = lambda x, return_g=True: np.sin(x) if not return_g else ( np.sin(x), sdiag( np.cos(x) ) )
|
||||
fun = lambda x, return_g=True: np.sin(x) if not return_g else ( np.sin(x), utils.sdiag( np.cos(x) ) )
|
||||
x = np.array([np.pi-0.3, np.pi+0.1, 0])
|
||||
pnt = NewtonRoot(comments=True).root(fun,x)
|
||||
print pnt
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import numpy as np
|
||||
from SimPEG.utils import mkvc
|
||||
|
||||
from SimPEG import utils
|
||||
|
||||
|
||||
class BaseMesh(object):
|
||||
@@ -12,6 +11,8 @@ class BaseMesh(object):
|
||||
:param numpy.array,list x0: Origin of the mesh (dim, )
|
||||
|
||||
"""
|
||||
__metaclass__ = utils.Save.Savable
|
||||
|
||||
def __init__(self, n, x0=None):
|
||||
|
||||
# Check inputs
|
||||
@@ -78,7 +79,7 @@ class BaseMesh(object):
|
||||
x_array = np.ones((x.size, len(x)))
|
||||
# Unwrap it and put it in a np array
|
||||
for i, xi in enumerate(x):
|
||||
x_array[:, i] = mkvc(xi)
|
||||
x_array[:, i] = utils.mkvc(xi)
|
||||
x = x_array
|
||||
|
||||
assert type(x) == np.ndarray, "x must be a numpy array"
|
||||
@@ -91,7 +92,7 @@ class BaseMesh(object):
|
||||
if format == 'M':
|
||||
return xx.reshape(nn, order='F')
|
||||
elif format == 'V':
|
||||
return mkvc(xx)
|
||||
return utils.mkvc(xx)
|
||||
|
||||
def switchKernal(xx):
|
||||
"""Switches over the different options."""
|
||||
@@ -101,7 +102,7 @@ class BaseMesh(object):
|
||||
return outKernal(xx, nn)
|
||||
elif xType in ['F', 'E']:
|
||||
# This will only deal with components of fields, not full 'F' or 'E'
|
||||
xx = mkvc(xx) # unwrap it in case it is a matrix
|
||||
xx = utils.mkvc(xx) # unwrap it in case it is a matrix
|
||||
nn = self.nFv if xType == 'F' else self.nEv
|
||||
nn = np.r_[0, nn]
|
||||
|
||||
@@ -308,7 +309,7 @@ class BaseMesh(object):
|
||||
"""
|
||||
fget = lambda self: np.array([x for x in [self.nNx, self.nNy, self.nNz] if not x is None])
|
||||
return locals()
|
||||
nNv = property(**nNv())
|
||||
nNv = property(**nNv())
|
||||
|
||||
def nEx():
|
||||
doc = """
|
||||
|
||||
+100
-2
@@ -7,6 +7,9 @@ try:
|
||||
except Exception, e:
|
||||
print 'Warning: SimPEG table needs h5py to be installed.'
|
||||
|
||||
|
||||
SAVEABLES = {}
|
||||
|
||||
def natural_keys(text):
|
||||
'''
|
||||
alist.sort(key=natural_keys) sorts in human order
|
||||
@@ -57,6 +60,7 @@ class SimPEGTable:
|
||||
def _doEndIteration_hdf5_inv(invObj):
|
||||
invObj.save(invObj._invNodeIt)
|
||||
postIteration(invObj._invNodeIt)
|
||||
self.f.flush()
|
||||
invObj.hook(_doEndIteration_hdf5_inv, overwrite=True)
|
||||
|
||||
# Delete all iterates that did not finish.
|
||||
@@ -78,11 +82,11 @@ class SimPEGTable:
|
||||
def _doEndIteration_hdf5_opt(optObj, xt):
|
||||
optObj.save(optObj._optNodeIt)
|
||||
postIteration(optObj._optNodeIt)
|
||||
self.f.flush()
|
||||
invObj.opt.hook(_doEndIteration_hdf5_opt, overwrite=True)
|
||||
|
||||
|
||||
|
||||
|
||||
class hdf5Group(object):
|
||||
"""Has some low level support for wrapping the native HDF5-Group class"""
|
||||
|
||||
@@ -183,7 +187,7 @@ class hdf5Group(object):
|
||||
return s
|
||||
|
||||
def __str__(self):
|
||||
return '<%s "%s" (%d member%s)>' % (self.__class__.__name__, self.path, self.numChildren, '' if self.numChildren == 1 else 's')
|
||||
return '<%s "%s" (%d member%s, attrs=[%s])>' % (self.__class__.__name__, self.path, self.numChildren, '' if self.numChildren == 1 else 's',', '.join([a for a in self.attrs]))
|
||||
|
||||
|
||||
|
||||
@@ -204,3 +208,97 @@ class hdf5InversionIteration(hdf5Group):
|
||||
def __init__(self, T, groupNode):
|
||||
hdf5Group.__init__(self, T, groupNode)
|
||||
self.parentClass = hdf5Inversion
|
||||
|
||||
|
||||
|
||||
class Savable(type):
|
||||
def __new__(cls, name, bases, attrs):
|
||||
__init__ = attrs['__init__']
|
||||
def init_wrapper(self, *args, **kwargs):
|
||||
self._args_init = args
|
||||
self._kwargs_init = kwargs
|
||||
return __init__(self, *args,**kwargs)
|
||||
attrs['__init__'] = init_wrapper
|
||||
|
||||
newClass = super(Savable, cls).__new__(cls, name, bases, attrs)
|
||||
SAVEABLES[name] = newClass
|
||||
return newClass
|
||||
|
||||
|
||||
def saveSavable(obj, group):
|
||||
"""
|
||||
"""
|
||||
assert type(obj.__class__) is Savable, 'Can only save objects that are Savable objects.'
|
||||
|
||||
def doSave(grp, name, val):
|
||||
if type(val.__class__) is Savable:
|
||||
subgrp = grp.addGroup(name)
|
||||
saveInitArgs(val, subgrp)
|
||||
elif type(val) is np.ndarray:
|
||||
grp.setArray(name, val)
|
||||
elif type(val) in [list, tuple]:
|
||||
# Split up, and save each element
|
||||
for i, v in enumerate(val):
|
||||
doSave(grp, name + '[%d]'%i, v)
|
||||
else:
|
||||
# just try saving it as an attr
|
||||
grp.attrs[name] = val
|
||||
|
||||
group.attrs['__class__'] = obj.__class__.__name__
|
||||
for arg in obj._kwargs_init:
|
||||
doSave(group, '_kwarg_'+arg, obj._kwargs_init[arg])
|
||||
for i, arg in enumerate(obj._args_init):
|
||||
doSave(group, '_arg%d'%i, arg)
|
||||
|
||||
|
||||
def loadSavable(node):
|
||||
|
||||
args = ([a for a in node.attrs if '_arg' in a] + [a for a in node.children if '_arg' in a])
|
||||
kwargs = ([a for a in node.attrs if '_kwarg' in a] + [a for a in node.children if '_kwarg' in a])
|
||||
args.sort(key=utils.Save.natural_keys)
|
||||
kwargs.sort(key=utils.Save.natural_keys)
|
||||
|
||||
def get(node,key):
|
||||
if key in node.children: return node[key]
|
||||
elif key in node.attrs: return node.attrs[key]
|
||||
|
||||
ARGS = []
|
||||
for name in args:
|
||||
val = get(node, name)
|
||||
if val.__class__ is h5py.Dataset: val = val[:]
|
||||
if '[' in name: # We are reloading a list
|
||||
ind = int(name[4:name.index('[')])
|
||||
if len(ARGS) is ind: # Create the list
|
||||
ARGS.append([val])
|
||||
else:
|
||||
ARGS[ind].append(val)
|
||||
elif issubclass(val.__class__,hdf5Group):
|
||||
ARGS.append(load(val))
|
||||
else:
|
||||
ind = int(name[4:])
|
||||
ARGS.append(val)
|
||||
|
||||
KWARGS = {}
|
||||
for name in kwargs:
|
||||
val = get(node, name)
|
||||
if val.__class__ is h5py.Dataset: val = val[:]
|
||||
if '[' in name: # We are reloading a list
|
||||
key = name[7:name.index('[')]
|
||||
if key not in KWARGS: # Create the list
|
||||
KWARGS[key] = [val]
|
||||
else:
|
||||
KWARGS[key].append(val)
|
||||
elif issubclass(val.__class__,hdf5Group):
|
||||
key = name[7:]
|
||||
KWARGS[key] = load(val)
|
||||
else:
|
||||
key = name[7:]
|
||||
KWARGS[key] = val
|
||||
|
||||
cls = get(node, '__class__')
|
||||
if cls in SAVEABLES:
|
||||
return SAVEABLES[cls](*ARGS,**KWARGS)
|
||||
else:
|
||||
print 'Warning: %s Class not found in SimPEG.utils.Save.SAVABLES' % cls
|
||||
return (cls, ARGS, KWARGS)
|
||||
|
||||
|
||||
@@ -73,11 +73,9 @@ class Solver(object):
|
||||
Jacobi = sdiag(1.0/M[1].diagonal())
|
||||
options['M'] = Jacobi
|
||||
elif M[0] is 'GS':
|
||||
LL = sp.tril(M[1])
|
||||
UU = sp.triu(M[1])
|
||||
DD = sdiag(M[1].diagonal())
|
||||
Uinv = Solver(UU, flag='U')
|
||||
Linv = Solver(LL, flag='L')
|
||||
Uinv = Solver(M[1], flag='U')
|
||||
Linv = Solver(M[1], flag='L')
|
||||
def GS(f):
|
||||
return Uinv.solve(DD*Linv.solve(f))
|
||||
options['M'] = sp.linalg.LinearOperator( A.shape, GS, dtype=A.dtype )
|
||||
|
||||
@@ -45,7 +45,7 @@ def setKwargs(obj, **kwargs):
|
||||
setattr(obj, attr, kwargs[attr])
|
||||
else:
|
||||
raise Exception('%s attr is not recognized' % attr)
|
||||
hook(obj,callHooks, silent=True)
|
||||
|
||||
hook(obj,hook, silent=True)
|
||||
hook(obj,setKwargs, silent=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user