Initial work on defaults.

This commit is contained in:
Rowan Cockett
2016-02-15 10:51:43 -08:00
parent 463b9b6164
commit 4776544b30
2 changed files with 64 additions and 13 deletions
+50 -13
View File
@@ -34,6 +34,18 @@ class Property(object):
setattr(self, '_%sMap'%prop.name, val)
return property(fget=fget, fset=fset, doc=prop.doc)
def _getDefaultProperty(self):
prop = self
def fget(self):
return getattr(self, '_%sDefault'%prop.name, None)
def fset(self, val):
if prop.propertyLink is not None:
linkName, linkMap = prop.propertyLink
assert getattr(self, '%sDefault'%linkName, None) is None, 'Cannot set both sides of a linked property.'
assert isinstance(val, np.ndarray) or np.isscalar(val), 'Default must be a scalar or a numpy array.'
setattr(self, '_%sDefault'%prop.name, val)
return property(fget=fget, fset=fset, doc=prop.doc)
def _getIndexProperty(self):
prop = self
def fget(self):
@@ -47,12 +59,17 @@ class Property(object):
def fget(self):
mapping = getattr(self, '%sMap'%prop.name)
if mapping is None and prop.propertyLink is None:
return prop.defaultVal
return getattr(self, '%sDefault'%prop.name)
if mapping is None and prop.propertyLink is not None:
linkName, linkMapClass = prop.propertyLink
linkMap = linkMapClass(None)
if getattr(self, '%sMap'%linkName, None) is None:
# *
print linkName, getattr(self.propMap, '_%sDefault'%linkName, None)
if getattr(self, '%sMap'%linkName, None) is None and getattr(self.propMap, '_%sDefault'%linkName, None) is not None:
# We have a default
return linkMap * getattr(self, '%sDefault'%linkName, None)
elif getattr(self, '%sMap'%linkName, None) is None:
return prop.defaultVal
m = getattr(self, '%s'%linkName)
return linkMap * m
@@ -110,6 +127,12 @@ class Property(object):
return getattr(self.propMap, '_%sMap'%prop.name, None)
return property(fget=fget)
def _getModelDefaultProperty(self):
prop = self
def fget(self):
return getattr(self.propMap, '_%sDefault'%prop.name, prop.defaultVal)
return property(fget=fget)
class PropModel(object):
@@ -150,8 +173,9 @@ class _PropMapMetaClass(type):
for attr in keys:
if isinstance(attrs[attr], Property):
attrs[attr].name = attr
attrs[attr + 'Map' ] = attrs[attr]._getMapProperty()
attrs[attr + 'Index'] = attrs[attr]._getIndexProperty()
attrs[attr + 'Map' ] = attrs[attr]._getMapProperty()
attrs[attr + 'Default'] = attrs[attr]._getDefaultProperty()
attrs[attr + 'Index' ] = attrs[attr]._getIndexProperty()
_properties[attr] = attrs[attr]
attrs.pop(attr)
@@ -181,11 +205,12 @@ class _PropMapMetaClass(type):
for attr in _properties:
prop = _properties[attr]
attrs[attr ] = prop._getProperty()
attrs[attr + 'Map' ] = prop._getModelMapProperty()
attrs[attr + 'Proj' ] = prop._getModelProjProperty()
attrs[attr + 'Model'] = prop._getModelProperty()
attrs[attr + 'Deriv'] = prop._getModelDerivProperty()
attrs[attr ] = prop._getProperty()
attrs[attr + 'Map' ] = prop._getModelMapProperty()
attrs[attr + 'Default'] = prop._getModelDefaultProperty()
attrs[attr + 'Proj' ] = prop._getModelProjProperty()
attrs[attr + 'Model' ] = prop._getModelProperty()
attrs[attr + 'Deriv' ] = prop._getModelDerivProperty()
return type(name.replace('PropMap', 'PropModel'), (PropModel, ), attrs)
@@ -198,8 +223,8 @@ class PropMap(object):
PropMap takes a multi parameter model and maps it to the equivalent PropModel
"""
if type(mappings) is dict:
assert np.all([k in ['maps', 'slices'] for k in mappings]), 'Dict must only have properties "maps" and "slices"'
self.setup(mappings['maps'], slices=mappings['slices'])
assert np.all([k in ['maps', 'slices', 'defaults'] for k in mappings]), 'Dict must only have properties "maps", "slices" and "defaults"'
self.setup(mappings['maps'], slices=mappings.get('slices',{}), defaults=mappings.get('defaults',{}))
elif type(mappings) is list:
self.setup(mappings)
elif isinstance(mappings, Maps.IdentityMap):
@@ -208,7 +233,7 @@ class PropMap(object):
raise Exception('mappings must be a dict, a mapping, or a list of tuples.')
def setup(self, maps, slices=None):
def setup(self, maps, slices=None, defaults=None):
"""
Sets up the maps and slices for the PropertyMap
@@ -231,6 +256,13 @@ class PropMap(object):
s in self._properties and
(type(slices[s]) in [slice, list] or isinstance(slices[s], np.ndarray))
for s in slices]), 'Slices must be for each property'
if defaults is None:
defaults = dict()
else:
assert np.all([
s in self._properties and
(np.isscalar(defaults[s]) or isinstance(defaults[s], np.ndarray))
for s in defaults]), 'Defaults must be for each property'
self.clearMaps()
@@ -239,7 +271,12 @@ class PropMap(object):
setattr(self, '%sMap'%name, mapping)
setattr(self, '%sIndex'%name, slices.get(name, slice(nP, nP + mapping.nP)))
nP += mapping.nP
self.nP = nP
self.nP = nP
for key in defaults:
setattr(self, '%sDefault'%key, defaults[key])
@property
def defaultInvProp(self):
+14
View File
@@ -56,8 +56,22 @@ class TestPropMaps(unittest.TestCase):
assert np.all(m.sigma == np.exp(np.r_[1.,2,3]))
assert m.sigmaDeriv is not None
assert m.mu == mu_0
assert m.nP == 3
def test_defaultOverride(self):
expMap = Maps.ExpMap(Mesh.TensorMesh((3,)))
PM = MyReciprocalPropMap({'maps':[('sigma', expMap)], 'defaults':{'mu':mu_0*2}})
self.assertRaises(Exception, MyReciprocalPropMap, {'maps':[('sigma', expMap)], 'defaults':{'mu':mu_0*2, 'mui':5}}) # Cannot set both sides of the default
m = PM(np.r_[1.,2,3])
assert np.all(m.sigmaModel == np.r_[1,2,3])
self.assertEqual(m.mu, mu_0 * 2)
# self.assertEqual(m.mui, 1/(mu_0 * 2))
def test_slices(self):
expMap = Maps.ExpMap(Mesh.TensorMesh((3,)))
PM = MyPropMap({'maps':[('sigma', expMap)], 'slices':{'sigma':[2,1,0]}})