MeatClasses to make every SimPEG object saveable to an hdf5 file.

cleaned up imports in a lot of places.
Made solver not copy matrices around for GS preconditioning.
This commit is contained in:
rowanc1
2013-12-06 11:23:01 -08:00
parent c1767939bb
commit 5e0fb8642d
9 changed files with 188 additions and 76 deletions
+2
View File
@@ -1,3 +1,5 @@
import numpy as np
import scipy.sparse as sp
import utils
from utils import Solver
import mesh
+3
View File
@@ -9,6 +9,9 @@ import matplotlib.pyplot as plt
class LinearProblem(Problem):
"""docstring for LinearProblem"""
def __init__(self, *args, **kwargs):
Problem.__init__(self, *args, **kwargs)
def dpred(self, m, u=None):
return self.G.dot(m)
+10 -10
View File
@@ -1,6 +1,4 @@
import numpy as np
from SimPEG.utils import mkvc, sdiag, count, timeIt
import scipy.sparse as sp
from SimPEG import utils, np, sp
norm = np.linalg.norm
@@ -37,6 +35,8 @@ class Problem(object):
to (locally) find how model parameters change the data, and optimize!
"""
__metaclass__ = utils.Save.Savable
counter = None
def __init__(self, mesh):
@@ -85,7 +85,7 @@ class Problem(object):
def dobs(self, value):
self._dobs = value
@count
@utils.count
def dpred(self, m, u=None):
"""
Predicted data.
@@ -97,7 +97,7 @@ class Problem(object):
u = self.field(m)
return self.P*u
@count
@utils.count
def dataResidual(self, m, u=None):
"""
:param numpy.array m: geophysical model
@@ -117,7 +117,7 @@ class Problem(object):
return self.dpred(m, u=u) - self.dobs
@timeIt
@utils.timeIt
def J(self, m, v, u=None):
"""
:param numpy.array m: model
@@ -147,7 +147,7 @@ class Problem(object):
"""
raise NotImplementedError('J is not yet implemented.')
@timeIt
@utils.timeIt
def Jt(self, m, v, u=None):
"""
:param numpy.array m: model
@@ -161,7 +161,7 @@ class Problem(object):
raise NotImplementedError('Jt is not yet implemented.')
@timeIt
@utils.timeIt
def J_approx(self, m, v, u=None):
"""
@@ -176,7 +176,7 @@ class Problem(object):
"""
return self.J(m, v, u)
@timeIt
@utils.timeIt
def Jt_approx(self, m, v, u=None):
"""
:param numpy.array m: model
@@ -241,7 +241,7 @@ class Problem(object):
dobs = self.dpred(m,u=u)
noise = std*abs(dobs)*np.random.randn(*dobs.shape)
dobs = dobs+noise
eps = np.linalg.norm(mkvc(dobs),2)*1e-5
eps = np.linalg.norm(utils.mkvc(dobs),2)*1e-5
Wd = 1/(abs(dobs)*std+eps)
return dobs, Wd
+19 -19
View File
@@ -1,7 +1,5 @@
import numpy as np
import scipy.sparse as sp
import SimPEG
from SimPEG.utils import sdiag, mkvc, setKwargs, checkStoppers, printStoppers, count, timeIt, callHooks
from SimPEG import utils, sp, np
from Optimize import Remember
from BetaSchedule import Cooling
from SimPEG.inverse import IterationPrinters, StoppingCriteria
@@ -9,6 +7,8 @@ from SimPEG.inverse import IterationPrinters, StoppingCriteria
class BaseInversion(object):
"""docstring for BaseInversion"""
__metaclass__ = utils.Save.Savable
maxIter = 1 #: Maximum number of iterations
name = 'BaseInversion'
@@ -17,11 +17,11 @@ class BaseInversion(object):
comment = '' #: Used by some functions to indicate what is going on in the algorithm
counter = None #: Set this to a SimPEG.utils.Counter() if you want to count things
beta0 = None #: The initial Beta (regularization parameter)
beta0 = None #: The initial Beta (regularization parameter)
beta0_ratio = 0.1 #: When beta0 is set to None, estimateBeta0 is used with this ratio
def __init__(self, prob, reg, opt, **kwargs):
setKwargs(self, **kwargs)
utils.setKwargs(self, **kwargs)
self.prob = prob
self.reg = reg
self.opt = opt
@@ -46,7 +46,7 @@ class BaseInversion(object):
Standard deviation weighting matrix.
"""
if getattr(self,'_Wd',None) is None:
eps = np.linalg.norm(mkvc(self.prob.dobs),2)*1e-5
eps = np.linalg.norm(utils.mkvc(self.prob.dobs),2)*1e-5
self._Wd = 1/(abs(self.prob.dobs)*self.prob.std+eps)
return self._Wd
@Wd.setter
@@ -70,7 +70,7 @@ class BaseInversion(object):
def phi_d_target(self, value):
self._phi_d_target = value
@timeIt
@utils.timeIt
def run(self, m0):
"""run(m0)
@@ -89,7 +89,7 @@ class BaseInversion(object):
return self.m
@callHooks('startup')
@utils.callHooks('startup')
def startup(self, m0):
"""
**startup** is called at the start of any new run call.
@@ -109,7 +109,7 @@ class BaseInversion(object):
self.phi_d_last = np.nan
self.phi_m_last = np.nan
@callHooks('doStartIteration')
@utils.callHooks('doStartIteration')
def doStartIteration(self):
"""
**doStartIteration** is called at the end of each run iteration.
@@ -120,7 +120,7 @@ class BaseInversion(object):
self._beta = self.getBeta()
@callHooks('doEndIteration')
@utils.callHooks('doEndIteration')
def doEndIteration(self):
"""
**doEndIteration** is called at the end of each run iteration.
@@ -179,7 +179,7 @@ class BaseInversion(object):
def stoppingCriteria(self):
if self.debug: print 'checking stoppingCriteria'
return checkStoppers(self, self.stoppers)
return utils.checkStoppers(self, self.stoppers)
def printDone(self):
@@ -189,7 +189,7 @@ class BaseInversion(object):
"""
printStoppers(self, self.stoppers)
@callHooks('finish')
@utils.callHooks('finish')
def finish(self):
"""finish()
@@ -197,7 +197,7 @@ class BaseInversion(object):
"""
pass
@timeIt
@utils.timeIt
def evalFunction(self, m, return_g=True, return_H=True):
"""evalFunction(m, return_g=True, return_H=True)
@@ -207,7 +207,7 @@ class BaseInversion(object):
u = self.prob.field(m)
if self._iter is 0 and self._beta is None:
self._beta = self.beta0 = self.estimateBeta0(u=u)
self._beta = self.beta0 = self.estimateBeta0(u=u,ratio=self.beta0_ratio)
phi_d = self.dataObj(m, u)
phi_m = self.reg.modelObj(m)
@@ -237,7 +237,7 @@ class BaseInversion(object):
out += (operator,)
return out if len(out) > 1 else out[0]
@timeIt
@utils.timeIt
def dataObj(self, m, u=None):
"""dataObj(m, u=None)
@@ -257,10 +257,10 @@ class BaseInversion(object):
"""
# TODO: ensure that this is a data is vector and Wd is a matrix.
R = self.Wd*self.prob.dataResidual(m, u=u)
R = mkvc(R)
R = utils.mkvc(R)
return 0.5*np.vdot(R, R)
@timeIt
@utils.timeIt
def dataObjDeriv(self, m, u=None):
"""dataObjDeriv(m, u=None)
@@ -302,7 +302,7 @@ class BaseInversion(object):
return dmisfit
@timeIt
@utils.timeIt
def dataObj2Deriv(self, m, v, u=None):
"""dataObj2Deriv(m, v, u=None)
+44 -34
View File
@@ -1,9 +1,6 @@
import numpy as np
from SimPEG import Solver, utils, sp, np
import matplotlib.pyplot as plt
from SimPEG.utils import mkvc, sdiag, setKwargs, printTitles, printLine, printStoppers, checkStoppers, count, timeIt, callHooks
norm = np.linalg.norm
import scipy.sparse as sp
from SimPEG import Solver
__all__ = ['Minimize', 'Remember', 'SteepestDescent', 'BFGS', 'GaussNewton', 'InexactGaussNewton', 'ProjectedGradient', 'NewtonRoot', 'StoppingCriteria', 'IterationPrinters']
@@ -85,6 +82,8 @@ class Minimize(object):
Minimize is a general class for derivative based optimization.
"""
__metaclass__ = utils.Save.Savable
name = "General Optimization Algorithm" #: The name of the optimization algorithm
maxIter = 20 #: Maximum number of iterations
@@ -110,9 +109,9 @@ class Minimize(object):
self.printers = [IterationPrinters.iteration, IterationPrinters.f, IterationPrinters.norm_g, IterationPrinters.totalLS]
self.printersLS = [IterationPrinters.iterationLS, IterationPrinters.LS_ft, IterationPrinters.LS_t, IterationPrinters.LS_armijoGoldstein]
setKwargs(self, **kwargs)
utils.setKwargs(self, **kwargs)
@timeIt
@utils.timeIt
def minimize(self, evalFunction, x0):
"""minimize(evalFunction, x0)
@@ -190,7 +189,7 @@ class Minimize(object):
def parent(self, value):
self._parent = value
@callHooks('startup')
@utils.callHooks('startup')
def startup(self, x0):
"""
**startup** is called at the start of any new minimize call.
@@ -215,8 +214,8 @@ class Minimize(object):
self.f_last = np.nan
self.x_last = x0
@count
@callHooks('doStartIteration')
@utils.count
@utils.callHooks('doStartIteration')
def doStartIteration(self):
"""doStartIteration()
@@ -238,9 +237,9 @@ class Minimize(object):
"""
pad = ' '*10 if inLS else ''
name = self.name if not inLS else self.nameLS
printTitles(self, self.printers if not inLS else self.printersLS, name, pad)
utils.printTitles(self, self.printers if not inLS else self.printersLS, name, pad)
@callHooks('printIter')
@utils.callHooks('printIter')
def printIter(self, inLS=False):
"""
**printIter** is called directly after function evaluations.
@@ -250,7 +249,7 @@ class Minimize(object):
"""
pad = ' '*10 if inLS else ''
printLine(self, self.printers if not inLS else self.printersLS, pad=pad)
utils.printLine(self, self.printers if not inLS else self.printersLS, pad=pad)
def printDone(self, inLS=False):
"""
@@ -263,10 +262,10 @@ class Minimize(object):
pad = ' '*10 if inLS else ''
stop, done = (' STOP! ', ' DONE! ') if not inLS else ('----------------', ' End Linesearch ')
stoppers = self.stoppers if not inLS else self.stoppersLS
printStoppers(self, stoppers, pad='', stop=stop, done=done)
utils.printStoppers(self, stoppers, pad='', stop=stop, done=done)
@callHooks('finish')
@utils.callHooks('finish')
def finish(self):
"""finish()
@@ -282,10 +281,10 @@ class Minimize(object):
if self._iter == 0:
self.f0 = self.f
self.g0 = self.g
return checkStoppers(self, self.stoppers if not inLS else self.stoppersLS)
return utils.checkStoppers(self, self.stoppers if not inLS else self.stoppersLS)
@timeIt
@callHooks('projection')
@utils.timeIt
@utils.callHooks('projection')
def projection(self, p):
"""projection(p)
@@ -299,7 +298,7 @@ class Minimize(object):
"""
return p
@timeIt
@utils.timeIt
def findSearchDirection(self):
"""findSearchDirection()
@@ -330,7 +329,7 @@ class Minimize(object):
"""
return -self.g
@count
@utils.count
def scaleSearchDirection(self, p):
"""scaleSearchDirection(p)
@@ -349,7 +348,7 @@ class Minimize(object):
nameLS = "Armijo linesearch" #: The line-search name
@timeIt
@utils.timeIt
def modifySearchDirection(self, p):
"""modifySearchDirection(p)
@@ -387,7 +386,7 @@ class Minimize(object):
return self._LS_xt, self._iterLS < self.maxIterLS
@count
@utils.count
def modifySearchDirectionBreak(self, p):
"""modifySearchDirectionBreak(p)
@@ -409,8 +408,8 @@ class Minimize(object):
print 'The linesearch got broken. Boo.'
return p, False
@count
@callHooks('doEndIteration')
@utils.count
@utils.callHooks('doEndIteration')
def doEndIteration(self, xt):
"""doEndIteration(xt)
@@ -437,6 +436,8 @@ class Minimize(object):
if getattr(self,'parent',None) is None:
group.setArray('x', self.xc)
else: # Assume inversion is the parent
group.attrs['phi_d'] = self.parent.phi_d
group.attrs['phi_m'] = self.parent.phi_m
group.setArray('m', self.xc)
group.setArray('dpred', self.parent.dpred)
@@ -526,7 +527,7 @@ class ProjectedGradient(Minimize, Remember):
self.aSet_prev = self.activeSet(x0)
@count
@utils.count
def projection(self, x):
"""projection(x)
@@ -535,7 +536,7 @@ class ProjectedGradient(Minimize, Remember):
"""
return np.median(np.c_[self.lower,x,self.upper],axis=1)
@count
@utils.count
def activeSet(self, x):
"""activeSet(x)
@@ -544,7 +545,7 @@ class ProjectedGradient(Minimize, Remember):
"""
return np.logical_or(x == self.lower, x == self.upper)
@count
@utils.count
def inactiveSet(self, x):
"""inactiveSet(x)
@@ -553,7 +554,7 @@ class ProjectedGradient(Minimize, Remember):
"""
return np.logical_not(self.activeSet(x))
@count
@utils.count
def bindingSet(self, x):
"""bindingSet(x)
@@ -566,7 +567,7 @@ class ProjectedGradient(Minimize, Remember):
bind_low = np.logical_and(x == self.upper, self.g <= 0)
return np.logical_or(bind_up, bind_low)
@timeIt
@utils.timeIt
def findSearchDirection(self):
"""findSearchDirection()
@@ -611,7 +612,7 @@ class ProjectedGradient(Minimize, Remember):
# aSet_after = self.activeSet(self.xc+p)
return p
@timeIt
@utils.timeIt
def _doEndIteration_ProjectedGradient(self, xt):
"""_doEndIteration_ProjectedGradient(xt)"""
aSet = self.activeSet(xt)
@@ -646,6 +647,9 @@ class BFGS(Minimize, Remember):
name = 'BFGS'
nbfgs = 10
def __init__(self, **kwargs):
Minimize.__init__(self, **kwargs)
@property
def bfgsH0(self):
"""
@@ -711,7 +715,10 @@ class BFGS(Minimize, Remember):
class GaussNewton(Minimize, Remember):
name = 'Gauss Newton'
@timeIt
def __init__(self, **kwargs):
Minimize.__init__(self, **kwargs)
@utils.timeIt
def findSearchDirection(self):
return Solver(self.H).solve(-self.g)
@@ -758,7 +765,7 @@ class InexactGaussNewton(BFGS, Minimize, Remember):
def approxHinv(self, value):
self._approxHinv = value
@timeIt
@utils.timeIt
def findSearchDirection(self):
Hinv = Solver(self.H, doDirect=False, options={'iterSolver': 'CG', 'M': self.approxHinv, 'tol': self.tolCG, 'maxIter': self.maxIterCG})
p = Hinv.solve(-self.g)
@@ -768,7 +775,10 @@ class InexactGaussNewton(BFGS, Minimize, Remember):
class SteepestDescent(Minimize, Remember):
name = 'Steepest Descent'
@timeIt
def __init__(self, **kwargs):
Minimize.__init__(self, **kwargs)
@utils.timeIt
def findSearchDirection(self):
return -self.g
@@ -801,7 +811,7 @@ class NewtonRoot(object):
doLS = True
def __init__(self, **kwargs):
setKwargs(self, **kwargs)
utils.setKwargs(self, **kwargs)
def root(self, fun, x):
"""root(fun, x)
@@ -875,7 +885,7 @@ if __name__ == '__main__':
print 'test the newtonRoot finding.'
fun = lambda x, return_g=True: np.sin(x) if not return_g else ( np.sin(x), sdiag( np.cos(x) ) )
fun = lambda x, return_g=True: np.sin(x) if not return_g else ( np.sin(x), utils.sdiag( np.cos(x) ) )
x = np.array([np.pi-0.3, np.pi+0.1, 0])
pnt = NewtonRoot(comments=True).root(fun,x)
print pnt
+7 -6
View File
@@ -1,6 +1,5 @@
import numpy as np
from SimPEG.utils import mkvc
from SimPEG import utils
class BaseMesh(object):
@@ -12,6 +11,8 @@ class BaseMesh(object):
:param numpy.array,list x0: Origin of the mesh (dim, )
"""
__metaclass__ = utils.Save.Savable
def __init__(self, n, x0=None):
# Check inputs
@@ -78,7 +79,7 @@ class BaseMesh(object):
x_array = np.ones((x.size, len(x)))
# Unwrap it and put it in a np array
for i, xi in enumerate(x):
x_array[:, i] = mkvc(xi)
x_array[:, i] = utils.mkvc(xi)
x = x_array
assert type(x) == np.ndarray, "x must be a numpy array"
@@ -91,7 +92,7 @@ class BaseMesh(object):
if format == 'M':
return xx.reshape(nn, order='F')
elif format == 'V':
return mkvc(xx)
return utils.mkvc(xx)
def switchKernal(xx):
"""Switches over the different options."""
@@ -101,7 +102,7 @@ class BaseMesh(object):
return outKernal(xx, nn)
elif xType in ['F', 'E']:
# This will only deal with components of fields, not full 'F' or 'E'
xx = mkvc(xx) # unwrap it in case it is a matrix
xx = utils.mkvc(xx) # unwrap it in case it is a matrix
nn = self.nFv if xType == 'F' else self.nEv
nn = np.r_[0, nn]
@@ -308,7 +309,7 @@ class BaseMesh(object):
"""
fget = lambda self: np.array([x for x in [self.nNx, self.nNy, self.nNz] if not x is None])
return locals()
nNv = property(**nNv())
nNv = property(**nNv())
def nEx():
doc = """
+100 -2
View File
@@ -7,6 +7,9 @@ try:
except Exception, e:
print 'Warning: SimPEG table needs h5py to be installed.'
SAVEABLES = {}
def natural_keys(text):
'''
alist.sort(key=natural_keys) sorts in human order
@@ -57,6 +60,7 @@ class SimPEGTable:
def _doEndIteration_hdf5_inv(invObj):
invObj.save(invObj._invNodeIt)
postIteration(invObj._invNodeIt)
self.f.flush()
invObj.hook(_doEndIteration_hdf5_inv, overwrite=True)
# Delete all iterates that did not finish.
@@ -78,11 +82,11 @@ class SimPEGTable:
def _doEndIteration_hdf5_opt(optObj, xt):
optObj.save(optObj._optNodeIt)
postIteration(optObj._optNodeIt)
self.f.flush()
invObj.opt.hook(_doEndIteration_hdf5_opt, overwrite=True)
class hdf5Group(object):
"""Has some low level support for wrapping the native HDF5-Group class"""
@@ -183,7 +187,7 @@ class hdf5Group(object):
return s
def __str__(self):
return '<%s "%s" (%d member%s)>' % (self.__class__.__name__, self.path, self.numChildren, '' if self.numChildren == 1 else 's')
return '<%s "%s" (%d member%s, attrs=[%s])>' % (self.__class__.__name__, self.path, self.numChildren, '' if self.numChildren == 1 else 's',', '.join([a for a in self.attrs]))
@@ -204,3 +208,97 @@ class hdf5InversionIteration(hdf5Group):
def __init__(self, T, groupNode):
hdf5Group.__init__(self, T, groupNode)
self.parentClass = hdf5Inversion
class Savable(type):
def __new__(cls, name, bases, attrs):
__init__ = attrs['__init__']
def init_wrapper(self, *args, **kwargs):
self._args_init = args
self._kwargs_init = kwargs
return __init__(self, *args,**kwargs)
attrs['__init__'] = init_wrapper
newClass = super(Savable, cls).__new__(cls, name, bases, attrs)
SAVEABLES[name] = newClass
return newClass
def saveSavable(obj, group):
"""
"""
assert type(obj.__class__) is Savable, 'Can only save objects that are Savable objects.'
def doSave(grp, name, val):
if type(val.__class__) is Savable:
subgrp = grp.addGroup(name)
saveInitArgs(val, subgrp)
elif type(val) is np.ndarray:
grp.setArray(name, val)
elif type(val) in [list, tuple]:
# Split up, and save each element
for i, v in enumerate(val):
doSave(grp, name + '[%d]'%i, v)
else:
# just try saving it as an attr
grp.attrs[name] = val
group.attrs['__class__'] = obj.__class__.__name__
for arg in obj._kwargs_init:
doSave(group, '_kwarg_'+arg, obj._kwargs_init[arg])
for i, arg in enumerate(obj._args_init):
doSave(group, '_arg%d'%i, arg)
def loadSavable(node):
args = ([a for a in node.attrs if '_arg' in a] + [a for a in node.children if '_arg' in a])
kwargs = ([a for a in node.attrs if '_kwarg' in a] + [a for a in node.children if '_kwarg' in a])
args.sort(key=utils.Save.natural_keys)
kwargs.sort(key=utils.Save.natural_keys)
def get(node,key):
if key in node.children: return node[key]
elif key in node.attrs: return node.attrs[key]
ARGS = []
for name in args:
val = get(node, name)
if val.__class__ is h5py.Dataset: val = val[:]
if '[' in name: # We are reloading a list
ind = int(name[4:name.index('[')])
if len(ARGS) is ind: # Create the list
ARGS.append([val])
else:
ARGS[ind].append(val)
elif issubclass(val.__class__,hdf5Group):
ARGS.append(load(val))
else:
ind = int(name[4:])
ARGS.append(val)
KWARGS = {}
for name in kwargs:
val = get(node, name)
if val.__class__ is h5py.Dataset: val = val[:]
if '[' in name: # We are reloading a list
key = name[7:name.index('[')]
if key not in KWARGS: # Create the list
KWARGS[key] = [val]
else:
KWARGS[key].append(val)
elif issubclass(val.__class__,hdf5Group):
key = name[7:]
KWARGS[key] = load(val)
else:
key = name[7:]
KWARGS[key] = val
cls = get(node, '__class__')
if cls in SAVEABLES:
return SAVEABLES[cls](*ARGS,**KWARGS)
else:
print 'Warning: %s Class not found in SimPEG.utils.Save.SAVABLES' % cls
return (cls, ARGS, KWARGS)
+2 -4
View File
@@ -73,11 +73,9 @@ class Solver(object):
Jacobi = sdiag(1.0/M[1].diagonal())
options['M'] = Jacobi
elif M[0] is 'GS':
LL = sp.tril(M[1])
UU = sp.triu(M[1])
DD = sdiag(M[1].diagonal())
Uinv = Solver(UU, flag='U')
Linv = Solver(LL, flag='L')
Uinv = Solver(M[1], flag='U')
Linv = Solver(M[1], flag='L')
def GS(f):
return Uinv.solve(DD*Linv.solve(f))
options['M'] = sp.linalg.LinearOperator( A.shape, GS, dtype=A.dtype )
+1 -1
View File
@@ -45,7 +45,7 @@ def setKwargs(obj, **kwargs):
setattr(obj, attr, kwargs[attr])
else:
raise Exception('%s attr is not recognized' % attr)
hook(obj,callHooks, silent=True)
hook(obj,hook, silent=True)
hook(obj,setKwargs, silent=True)