diff --git a/SimPEG/Data.py b/SimPEG/Data.py index 7a793065..7b0145dc 100644 --- a/SimPEG/Data.py +++ b/SimPEG/Data.py @@ -4,7 +4,7 @@ import Utils, numpy as np class BaseData(object): """Data holds the observed data, and the standard deviations.""" - __metaclass__ = Utils.Save.Savable + __metaclass__ = Utils.SimPEGMetaClass std = None #: Estimated Standard Deviations dobs = None #: Observed data diff --git a/SimPEG/Inversion.py b/SimPEG/Inversion.py index 56a00f8f..64445ab2 100644 --- a/SimPEG/Inversion.py +++ b/SimPEG/Inversion.py @@ -7,7 +7,7 @@ class BaseInversion(object): """BaseInversion(objFunc, opt, **kwargs) """ - __metaclass__ = Utils.Save.Savable + __metaclass__ = Utils.SimPEGMetaClass name = 'BaseInversion' diff --git a/SimPEG/Mesh/LogicallyOrthogonalMesh.py b/SimPEG/Mesh/LogicallyOrthogonalMesh.py index e39c0f54..a8b4207f 100644 --- a/SimPEG/Mesh/LogicallyOrthogonalMesh.py +++ b/SimPEG/Mesh/LogicallyOrthogonalMesh.py @@ -26,7 +26,7 @@ class LogicallyOrthogonalMesh(BaseMesh, DiffOperators, InnerProducts, LomView): M.plotGrid(showIt=True) """ - __metaclass__ = Utils.Save.Savable + __metaclass__ = Utils.SimPEGMetaClass _meshType = 'LOM' diff --git a/SimPEG/Mesh/TensorMesh.py b/SimPEG/Mesh/TensorMesh.py index 341cfba7..b2981bdd 100644 --- a/SimPEG/Mesh/TensorMesh.py +++ b/SimPEG/Mesh/TensorMesh.py @@ -33,7 +33,7 @@ class TensorMesh(BaseMesh, TensorView, DiffOperators, InnerProducts): """ - __metaclass__ = Utils.Save.Savable + __metaclass__ = Utils.SimPEGMetaClass _meshType = 'TENSOR' diff --git a/SimPEG/Model.py b/SimPEG/Model.py index 44dc3908..6941607d 100644 --- a/SimPEG/Model.py +++ b/SimPEG/Model.py @@ -7,7 +7,7 @@ class BaseModel(object): """ - __metaclass__ = Utils.Save.Savable + __metaclass__ = Utils.SimPEGMetaClass counter = None #: A SimPEG.Utils.Counter object mesh = None #: A SimPEG Mesh diff --git a/SimPEG/ObjFunction.py b/SimPEG/ObjFunction.py index 91722b68..7d08b9fc 100644 --- a/SimPEG/ObjFunction.py +++ b/SimPEG/ObjFunction.py @@ -3,7 +3,7 @@ import Utils, Parameters, numpy as np, scipy.sparse as sp class BaseObjFunction(object): """BaseObjFunction(data, reg, **kwargs)""" - __metaclass__ = Utils.Save.Savable + __metaclass__ = Utils.SimPEGMetaClass beta = Parameters.ParameterProperty('beta', default=1, doc='Regularization trade-off parameter') diff --git a/SimPEG/Optimization.py b/SimPEG/Optimization.py index 4e5963f6..8e08250b 100644 --- a/SimPEG/Optimization.py +++ b/SimPEG/Optimization.py @@ -82,7 +82,7 @@ class Minimize(object): Minimize is a general class for derivative based optimization. """ - __metaclass__ = Utils.Save.Savable + __metaclass__ = Utils.SimPEGMetaClass name = "General Optimization Algorithm" #: The name of the optimization algorithm diff --git a/SimPEG/Problem.py b/SimPEG/Problem.py index df37d05a..314af81f 100644 --- a/SimPEG/Problem.py +++ b/SimPEG/Problem.py @@ -34,7 +34,7 @@ class BaseProblem(object): to (locally) find how model parameters change the data, and optimize! """ - __metaclass__ = Utils.Save.Savable + __metaclass__ = Utils.SimPEGMetaClass counter = None #: A SimPEG.Utils.Counter object diff --git a/SimPEG/Regularization.py b/SimPEG/Regularization.py index aa62449a..b360d726 100644 --- a/SimPEG/Regularization.py +++ b/SimPEG/Regularization.py @@ -10,7 +10,7 @@ class BaseRegularization(object): """ - __metaclass__ = Utils.Save.Savable + __metaclass__ = Utils.SimPEGMetaClass modelPair = Model.BaseModel #: Some regularizations only work on specific models diff --git a/SimPEG/Utils/Save.py b/SimPEG/Utils/Save.py deleted file mode 100644 index a9c77191..00000000 --- a/SimPEG/Utils/Save.py +++ /dev/null @@ -1,352 +0,0 @@ -import numpy as np -import time -import re - -try: - import h5py -except Exception, e: - print 'Warning: SimPEG.Utils.Save needs h5py to be installed.' - - -SAVEABLES = {} - -def natural_keys(text): - ''' - alist.sort(key=natural_keys) sorts in human order - http://nedbatchelder.com/blog/200712/human_sorting.html - (See Toothy's implementation in the comments) - ''' - atoi = lambda text: int(text) if text.isdigit() else text - return [ atoi(c) for c in re.split('(\d+)', text) ] - - -def preIteration(group): - group.attrs['complete'] = False - group.attrs['time'] = time.time() - -def postIteration(group): - group.attrs['time'] = time.time() - group.attrs['time'] - group.attrs['date'] = time.ctime() - group.attrs['complete'] = True - -class SimPEGTable: - """ - This is a wrapper class on the HDF5 file. - """ - def __init__(self, name, mode='a'): - if '.hdf5' not in name: - name += '.hdf5' - self.f = h5py.File(name, mode) - self.root = hdf5Group(self,self.f) - - self.inversions = hdf5InversionGroup(self,self.root.addGroup('inversions',soft=True)) - - def show(self): self.root.show() - - def saveInversion(self, invObj): - - # Create a new inversion anytime this is run. - def _startup_hdf5_inv(invObj, m0): - node = self.inversions.addGroup('%d'%self.inversions.numChildren) - saveSavable(invObj,node.addGroup('rebuild')) - results = node.addGroup('results') - preIteration(results) - invObj._invNode = results - self.f.flush() - invObj.hook(_startup_hdf5_inv, overwrite=True) - - # At the start of every iteration we will create a inversion iteration node. - def _doStartIteration_hdf5_inv(invObj): - invObj._invNodeIt = invObj._invNode.addGroup('%d'%(invObj.iter+1)) - preIteration(invObj._invNodeIt) - invObj.hook(_doStartIteration_hdf5_inv, overwrite=True) - - def _doEndIteration_hdf5_inv(invObj): - invObj.save(invObj._invNodeIt) - postIteration(invObj._invNodeIt) - self.f.flush() - invObj.hook(_doEndIteration_hdf5_inv, overwrite=True) - - # Delete all iterates that did not finish. - def _finish_hdf5_inv(invObj): - postIteration(invObj._invNode) - for it in invObj._invNode: - if not it.attrs['complete']: - del self.f[it.path] - del invObj._invNode - self.f.flush() - invObj.hook(_finish_hdf5_inv, overwrite=True) - - def _doStartIteration_hdf5_opt(optObj): - optObj._optNodeIt = optObj.parent._invNode.addGroup('%d.%d'%(optObj.parent.iter, optObj.iter)) - preIteration(optObj._optNodeIt) - invObj.opt.hook(_doStartIteration_hdf5_opt, overwrite=True) - - def _doEndIteration_hdf5_opt(optObj, xt): - optObj.save(optObj._optNodeIt) - postIteration(optObj._optNodeIt) - self.f.flush() - invObj.opt.hook(_doEndIteration_hdf5_opt, overwrite=True) - - - -class hdf5Group(object): - """Has some low level support for wrapping the native HDF5-Group class""" - - def __init__(self, T, groupNode): - self.T = T - # check if you are inputing a hdf5Group rather than a raw node, and act accordingly - if issubclass(groupNode.__class__, hdf5Group): - self.node = groupNode.node - else: - self.node = groupNode - - self.childClass = hdf5Group - self.parentClass = hdf5Group - - @property - def children(self): - """Children names in a list - - Use obj[name] to return the actual node. - """ - myChildren = [c for c in self.node] - myChildren.sort(key=natural_keys) - return myChildren - - @property - def numChildren(self): - """Returns the len(children)""" - return len(self.children) - - @property - def parent(self): - """Returns parent node""" - return self.parentClass(self.T, self.node.parent) - - @property - def name(self): - return self.path.split('/')[-1] - - @property - def path(self): - """Returns the root path of the group""" - return self.node.name - - @property - def attrs(self): - """Returns a list of attributes in the group""" - return self.node.attrs - - def addGroup(self, name, soft=False): - """Adds a child group to the current node.""" - if name in self.children and soft: - return self[name] - assert name not in self.children, 'Already a child called: '+self.path+'/'+name - return self.childClass(self.T, self.node.create_group(name)) - - def setArray(self, name, data): - a = self.node.create_dataset(name, data.shape) - a[...] = data - return a - - def __getitem__(self, val): - if type(val) is int: - val = self.children[val] - child = self.node[val] - if type(child) is h5py.Group: - child = self.childClass(self.T, child) - return child - - def __contains__(self, key): - return key in self.children - - def show(self, pad='', maxDepth=1, depth=0): - """ - Recursively show the structure of the database. - - For example:: - - - - - - - - - - - - - """ - s = self.__str__() - pad += ' '*4 - if maxDepth <= 0: print s - if depth >= maxDepth: return s - - for c in self.children: - if issubclass(self[c].__class__, hdf5Group): - s += '\n%s- %s' % (pad, self[c].show(pad=pad,depth=depth+1,maxDepth=maxDepth)) - else: - s += '\n%s- %s' % (pad, self[c].__str__()) - if depth is 0: - print s - else: - return s - - def __str__(self): - return '<%s "%s" (%d member%s, attrs=[%s])>' % (self.__class__.__name__, self.path, self.numChildren, '' if self.numChildren == 1 else 's',', '.join([a for a in self.attrs])) - - - -class hdf5InversionGroup(hdf5Group): - def __init__(self, T, groupNode): - hdf5Group.__init__(self, T, groupNode) - self.childClass = hdf5Inversion - -class hdf5Inversion(hdf5Group): - def __init__(self, T, groupNode): - hdf5Group.__init__(self, T, groupNode) - self.parentClass = hdf5InversionGroup - self.childClass = hdf5InversionResults - - def rebuild(self): - return loadSavable(self['rebuild']) - - @property - def results(self): return self['results'] - - -class hdf5InversionResults(hdf5Group): - def __init__(self, T, groupNode): - hdf5Group.__init__(self, T, groupNode) - self.parentClass = hdf5Inversion - self.childClass = hdf5InversionIteration - -class hdf5InversionIteration(hdf5Group): - def __init__(self, T, groupNode): - hdf5Group.__init__(self, T, groupNode) - self.parentClass = hdf5InversionResults - - - -class Savable(type): - def __new__(cls, name, bases, attrs): - __init__ = attrs['__init__'] - def init_wrapper(self, *args, **kwargs): - self._args_init = args - self._kwargs_init = kwargs - return __init__(self, *args,**kwargs) - attrs['__init__'] = init_wrapper - - newClass = super(Savable, cls).__new__(cls, name, bases, attrs) - SAVEABLES[name] = newClass - return newClass - - -def saveSavable(obj, group, debug=False): - """ - This creates softlinks if _savable exists in children object. - - The first object is always created. - """ - assert type(obj.__class__) is Savable, 'Can only save objects that are Savable objects.' - - def doSave(grp, name, val): - if debug: print name, val - if type(val.__class__) is Savable: - link = getattr(val,'_savable',None) - if link is not None: - group.node[name] = h5py.SoftLink(link.path) - if debug: 'Created a softlink path to %s' % link.path - else: - subgrp = grp.addGroup(name) - saveSavable(val, subgrp, debug=debug) - elif type(val) in [list, tuple]: - # Split up, and save each element - for i, v in enumerate(val): - doSave(grp, name + '[%d]'%i, v) - elif type(val) is np.ndarray: - grp.setArray(name, val) - elif val is None: - grp.attrs[name] = 'None' - else: - # just try saving it as an attr - try: - grp.attrs[name] = val - except Exception, e: - print 'Warning: Could not save %s, problems may arise in loading.' % name - - group.attrs['__class__'] = obj.__class__.__name__ - for arg in obj._kwargs_init: - doSave(group, '_kwarg_'+arg, obj._kwargs_init[arg]) - for i, arg in enumerate(obj._args_init): - doSave(group, '_arg%d'%i, arg) - obj._savable = group - - -def loadSavable(node, pointers=None): - """ - pointers allow things that point to the same node in the h5py file to - be returned as the same object, if they have already been created. - """ - - if pointers is None: pointers = [] - for pointer in pointers: - if pointer._savable.node == node.node: return pointer - - args = ([a for a in node.attrs if '_arg' in a] + [a for a in node.children if '_arg' in a]) - kwargs = ([a for a in node.attrs if '_kwarg' in a] + [a for a in node.children if '_kwarg' in a]) - args.sort(key=natural_keys) - kwargs.sort(key=natural_keys) - - def get(node,key): - if key in node.children: return node[key] - elif key in node.attrs: return node.attrs[key] - - ARGS = [] - for name in args: - val = get(node, name) - if val.__class__ is h5py.Dataset: val = val[:] - if val is 'None': val = None - if '[' in name: # We are reloading a list - ind = int(name[4:name.index('[')]) - if len(ARGS) is ind: # Create the list - ARGS.append([val]) - else: - ARGS[ind].append(val) - elif issubclass(val.__class__,hdf5Group): - ARGS.append(loadSavable(val,pointers=pointers)) - else: - ind = int(name[4:]) - ARGS.append(val) - - KWARGS = {} - for name in kwargs: - val = get(node, name) - if val.__class__ is h5py.Dataset: val = val[:] - if val is 'None': val = None - if '[' in name: # We are reloading a list - key = name[7:name.index('[')] - if key not in KWARGS: # Create the list - KWARGS[key] = [val] - else: - KWARGS[key].append(val) - elif issubclass(val.__class__,hdf5Group): - key = name[7:] - KWARGS[key] = loadSavable(val,pointers=pointers) - else: - key = name[7:] - KWARGS[key] = val - - cls = get(node, '__class__') - if cls in SAVEABLES: - try: - out = SAVEABLES[cls](*ARGS, **KWARGS) - out._savable = node - pointers.append(out) # Because this is recursive. - return out - except Exception, e: - print 'Warning: %s Class could not be initiated.' % cls - print 'ARGS: ', ARGS - print 'KWARGS: ', KWARGS - return (cls, ARGS, KWARGS, node) - else: - print 'Warning: %s Class not found in SimPEG.Utils.Save.SAVABLES' % cls - return (cls, ARGS, KWARGS, node) - diff --git a/SimPEG/Utils/__init__.py b/SimPEG/Utils/__init__.py index 1331cc8c..8ed65c84 100644 --- a/SimPEG/Utils/__init__.py +++ b/SimPEG/Utils/__init__.py @@ -4,7 +4,6 @@ from meshutils import exampleLomGird, meshTensors from lomutils import volTetra, faceInfo, inv2X2BlockDiagonal, inv3X3BlockDiagonal, indexCube from interputils import interpmat from ipythonutils import easyAnimate as animate -import Save import ModelBuilder import types @@ -12,6 +11,12 @@ import time import numpy as np from functools import wraps + +class SimPEGMetaClass(type): + def __new__(cls, name, bases, attrs): + return super(SimPEGMetaClass, cls).__new__(cls, name, bases, attrs) + + def hook(obj, method, name=None, overwrite=False, silent=False): """ This dynamically binds a method to the instance of the class.