mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-28 19:33:28 +08:00
Initial work on pickling for #226
This commit is contained in:
+30
-3
@@ -110,8 +110,6 @@ class Property(object):
|
||||
return getattr(self.propMap, '_%sMap'%prop.name, None)
|
||||
return property(fget=fget)
|
||||
|
||||
|
||||
|
||||
class PropModel(object):
|
||||
def __init__(self, propMap, vector):
|
||||
self.propMap = propMap
|
||||
@@ -189,6 +187,12 @@ class _PropMapMetaClass(type):
|
||||
|
||||
return type(name.replace('PropMap', 'PropModel'), (PropModel, ), attrs)
|
||||
|
||||
def fromPickle(name, properties, maps, slices):
|
||||
attrs = dict()
|
||||
for p in properties:
|
||||
attrs[p] = Property(**properties[p])
|
||||
PM = type(name, (PropMap,), attrs)
|
||||
return PM(dict(maps=maps, slices=slices))
|
||||
|
||||
class PropMap(object):
|
||||
__metaclass__ = _PropMapMetaClass
|
||||
@@ -197,6 +201,7 @@ class PropMap(object):
|
||||
"""
|
||||
PropMap takes a multi parameter model and maps it to the equivalent PropModel
|
||||
"""
|
||||
|
||||
if type(mappings) is dict:
|
||||
assert np.all([k in ['maps', 'slices'] for k in mappings]), 'Dict must only have properties "maps" and "slices"'
|
||||
self.setup(mappings['maps'], slices=mappings['slices'])
|
||||
@@ -239,7 +244,11 @@ class PropMap(object):
|
||||
setattr(self, '%sMap'%name, mapping)
|
||||
setattr(self, '%sIndex'%name, slices.get(name, slice(nP, nP + mapping.nP)))
|
||||
nP += mapping.nP
|
||||
self.nP = nP
|
||||
|
||||
self._maps = maps
|
||||
self._slices = slices
|
||||
|
||||
self.nP = nP
|
||||
|
||||
@property
|
||||
def defaultInvProp(self):
|
||||
@@ -253,9 +262,27 @@ class PropMap(object):
|
||||
setattr(self, '%sMap'%name, None)
|
||||
setattr(self, '%sIndex'%name, None)
|
||||
|
||||
self._maps = None
|
||||
self._slices = None
|
||||
|
||||
def __call__(self, vec):
|
||||
return self.PropModel(self, vec)
|
||||
|
||||
def __contains__(self, val):
|
||||
activeMaps = [name for name in self._properties if getattr(self, '%sMap'%name) is not None]
|
||||
return val in activeMaps
|
||||
|
||||
def __reduce__(self):
|
||||
|
||||
import cPickle
|
||||
props = dict()
|
||||
for p in self._properties:
|
||||
props[p] = self._properties[p].toJSON()
|
||||
className = self.__class__.__name__
|
||||
|
||||
|
||||
pickledMaps = []
|
||||
for name, mapping in self._maps:
|
||||
pickledMaps += [name, cPickle.dumps(mapping)]
|
||||
|
||||
return (fromPickle, (className, props, self._maps, self._slices))
|
||||
|
||||
@@ -65,8 +65,8 @@ def setKwargs(obj, ignore=[], **kwargs):
|
||||
else:
|
||||
raise Exception('%s attr is not recognized' % attr)
|
||||
|
||||
hook(obj,hook, silent=True)
|
||||
hook(obj,setKwargs, silent=True)
|
||||
# hook(obj,hook, silent=True)
|
||||
# hook(obj,setKwargs, silent=True)
|
||||
|
||||
def printTitles(obj, printers, name='Print Titles', pad=''):
|
||||
titles = ''
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
from SimPEG import *
|
||||
from scipy.constants import mu_0
|
||||
import cPickle
|
||||
|
||||
|
||||
class MyPropMap(Maps.PropMap):
|
||||
@@ -28,9 +29,10 @@ class TestPropMaps(unittest.TestCase):
|
||||
PM3 = MyPropMap({'maps':[('sigma', expMap)], 'slices':{'sigma':slice(0,3)}})
|
||||
|
||||
for PM in [PM1,PM2,PM3]:
|
||||
PM = cPickle.loads( cPickle.dumps(PM) )
|
||||
assert PM.defaultInvProp == 'sigma'
|
||||
assert PM.sigmaMap is not None
|
||||
assert PM.sigmaMap is expMap
|
||||
assert PM.sigmaMap.__class__ is expMap.__class__
|
||||
assert PM.sigmaIndex == slice(0,3)
|
||||
assert getattr(PM, 'sigma', None) is None
|
||||
assert PM.muMap is None
|
||||
@@ -52,7 +54,7 @@ class TestPropMaps(unittest.TestCase):
|
||||
assert m.muDeriv is None
|
||||
|
||||
assert np.all(m.sigmaModel == np.r_[1.,2,3])
|
||||
assert m.sigmaMap is expMap
|
||||
assert m.sigmaMap.__class__ is expMap.__class__
|
||||
assert np.all(m.sigma == np.exp(np.r_[1.,2,3]))
|
||||
assert m.sigmaDeriv is not None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user