mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-27 19:48:52 +08:00
Initial work on defaults.
This commit is contained in:
+50
-13
@@ -34,6 +34,18 @@ class Property(object):
|
||||
setattr(self, '_%sMap'%prop.name, val)
|
||||
return property(fget=fget, fset=fset, doc=prop.doc)
|
||||
|
||||
def _getDefaultProperty(self):
|
||||
prop = self
|
||||
def fget(self):
|
||||
return getattr(self, '_%sDefault'%prop.name, None)
|
||||
def fset(self, val):
|
||||
if prop.propertyLink is not None:
|
||||
linkName, linkMap = prop.propertyLink
|
||||
assert getattr(self, '%sDefault'%linkName, None) is None, 'Cannot set both sides of a linked property.'
|
||||
assert isinstance(val, np.ndarray) or np.isscalar(val), 'Default must be a scalar or a numpy array.'
|
||||
setattr(self, '_%sDefault'%prop.name, val)
|
||||
return property(fget=fget, fset=fset, doc=prop.doc)
|
||||
|
||||
def _getIndexProperty(self):
|
||||
prop = self
|
||||
def fget(self):
|
||||
@@ -47,12 +59,17 @@ class Property(object):
|
||||
def fget(self):
|
||||
mapping = getattr(self, '%sMap'%prop.name)
|
||||
if mapping is None and prop.propertyLink is None:
|
||||
return prop.defaultVal
|
||||
return getattr(self, '%sDefault'%prop.name)
|
||||
|
||||
if mapping is None and prop.propertyLink is not None:
|
||||
linkName, linkMapClass = prop.propertyLink
|
||||
linkMap = linkMapClass(None)
|
||||
if getattr(self, '%sMap'%linkName, None) is None:
|
||||
# *
|
||||
print linkName, getattr(self.propMap, '_%sDefault'%linkName, None)
|
||||
if getattr(self, '%sMap'%linkName, None) is None and getattr(self.propMap, '_%sDefault'%linkName, None) is not None:
|
||||
# We have a default
|
||||
return linkMap * getattr(self, '%sDefault'%linkName, None)
|
||||
elif getattr(self, '%sMap'%linkName, None) is None:
|
||||
return prop.defaultVal
|
||||
m = getattr(self, '%s'%linkName)
|
||||
return linkMap * m
|
||||
@@ -110,6 +127,12 @@ class Property(object):
|
||||
return getattr(self.propMap, '_%sMap'%prop.name, None)
|
||||
return property(fget=fget)
|
||||
|
||||
def _getModelDefaultProperty(self):
|
||||
prop = self
|
||||
def fget(self):
|
||||
return getattr(self.propMap, '_%sDefault'%prop.name, prop.defaultVal)
|
||||
return property(fget=fget)
|
||||
|
||||
|
||||
|
||||
class PropModel(object):
|
||||
@@ -150,8 +173,9 @@ class _PropMapMetaClass(type):
|
||||
for attr in keys:
|
||||
if isinstance(attrs[attr], Property):
|
||||
attrs[attr].name = attr
|
||||
attrs[attr + 'Map' ] = attrs[attr]._getMapProperty()
|
||||
attrs[attr + 'Index'] = attrs[attr]._getIndexProperty()
|
||||
attrs[attr + 'Map' ] = attrs[attr]._getMapProperty()
|
||||
attrs[attr + 'Default'] = attrs[attr]._getDefaultProperty()
|
||||
attrs[attr + 'Index' ] = attrs[attr]._getIndexProperty()
|
||||
_properties[attr] = attrs[attr]
|
||||
attrs.pop(attr)
|
||||
|
||||
@@ -181,11 +205,12 @@ class _PropMapMetaClass(type):
|
||||
for attr in _properties:
|
||||
prop = _properties[attr]
|
||||
|
||||
attrs[attr ] = prop._getProperty()
|
||||
attrs[attr + 'Map' ] = prop._getModelMapProperty()
|
||||
attrs[attr + 'Proj' ] = prop._getModelProjProperty()
|
||||
attrs[attr + 'Model'] = prop._getModelProperty()
|
||||
attrs[attr + 'Deriv'] = prop._getModelDerivProperty()
|
||||
attrs[attr ] = prop._getProperty()
|
||||
attrs[attr + 'Map' ] = prop._getModelMapProperty()
|
||||
attrs[attr + 'Default'] = prop._getModelDefaultProperty()
|
||||
attrs[attr + 'Proj' ] = prop._getModelProjProperty()
|
||||
attrs[attr + 'Model' ] = prop._getModelProperty()
|
||||
attrs[attr + 'Deriv' ] = prop._getModelDerivProperty()
|
||||
|
||||
return type(name.replace('PropMap', 'PropModel'), (PropModel, ), attrs)
|
||||
|
||||
@@ -198,8 +223,8 @@ class PropMap(object):
|
||||
PropMap takes a multi parameter model and maps it to the equivalent PropModel
|
||||
"""
|
||||
if type(mappings) is dict:
|
||||
assert np.all([k in ['maps', 'slices'] for k in mappings]), 'Dict must only have properties "maps" and "slices"'
|
||||
self.setup(mappings['maps'], slices=mappings['slices'])
|
||||
assert np.all([k in ['maps', 'slices', 'defaults'] for k in mappings]), 'Dict must only have properties "maps", "slices" and "defaults"'
|
||||
self.setup(mappings['maps'], slices=mappings.get('slices',{}), defaults=mappings.get('defaults',{}))
|
||||
elif type(mappings) is list:
|
||||
self.setup(mappings)
|
||||
elif isinstance(mappings, Maps.IdentityMap):
|
||||
@@ -208,7 +233,7 @@ class PropMap(object):
|
||||
raise Exception('mappings must be a dict, a mapping, or a list of tuples.')
|
||||
|
||||
|
||||
def setup(self, maps, slices=None):
|
||||
def setup(self, maps, slices=None, defaults=None):
|
||||
"""
|
||||
Sets up the maps and slices for the PropertyMap
|
||||
|
||||
@@ -231,6 +256,13 @@ class PropMap(object):
|
||||
s in self._properties and
|
||||
(type(slices[s]) in [slice, list] or isinstance(slices[s], np.ndarray))
|
||||
for s in slices]), 'Slices must be for each property'
|
||||
if defaults is None:
|
||||
defaults = dict()
|
||||
else:
|
||||
assert np.all([
|
||||
s in self._properties and
|
||||
(np.isscalar(defaults[s]) or isinstance(defaults[s], np.ndarray))
|
||||
for s in defaults]), 'Defaults must be for each property'
|
||||
|
||||
self.clearMaps()
|
||||
|
||||
@@ -239,7 +271,12 @@ class PropMap(object):
|
||||
setattr(self, '%sMap'%name, mapping)
|
||||
setattr(self, '%sIndex'%name, slices.get(name, slice(nP, nP + mapping.nP)))
|
||||
nP += mapping.nP
|
||||
self.nP = nP
|
||||
self.nP = nP
|
||||
|
||||
for key in defaults:
|
||||
setattr(self, '%sDefault'%key, defaults[key])
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def defaultInvProp(self):
|
||||
|
||||
@@ -56,8 +56,22 @@ class TestPropMaps(unittest.TestCase):
|
||||
assert np.all(m.sigma == np.exp(np.r_[1.,2,3]))
|
||||
assert m.sigmaDeriv is not None
|
||||
|
||||
assert m.mu == mu_0
|
||||
|
||||
assert m.nP == 3
|
||||
|
||||
def test_defaultOverride(self):
|
||||
expMap = Maps.ExpMap(Mesh.TensorMesh((3,)))
|
||||
PM = MyReciprocalPropMap({'maps':[('sigma', expMap)], 'defaults':{'mu':mu_0*2}})
|
||||
self.assertRaises(Exception, MyReciprocalPropMap, {'maps':[('sigma', expMap)], 'defaults':{'mu':mu_0*2, 'mui':5}}) # Cannot set both sides of the default
|
||||
|
||||
m = PM(np.r_[1.,2,3])
|
||||
assert np.all(m.sigmaModel == np.r_[1,2,3])
|
||||
|
||||
self.assertEqual(m.mu, mu_0 * 2)
|
||||
# self.assertEqual(m.mui, 1/(mu_0 * 2))
|
||||
|
||||
|
||||
def test_slices(self):
|
||||
expMap = Maps.ExpMap(Mesh.TensorMesh((3,)))
|
||||
PM = MyPropMap({'maps':[('sigma', expMap)], 'slices':{'sigma':[2,1,0]}})
|
||||
|
||||
Reference in New Issue
Block a user