From 4776544b301459e25834408ef82912a54814a573 Mon Sep 17 00:00:00 2001 From: Rowan Cockett Date: Mon, 15 Feb 2016 10:51:43 -0800 Subject: [PATCH] Initial work on defaults. --- SimPEG/PropMaps.py | 63 +++++++++++++++++++++++++++++-------- tests/base/test_PropMaps.py | 14 +++++++++ 2 files changed, 64 insertions(+), 13 deletions(-) diff --git a/SimPEG/PropMaps.py b/SimPEG/PropMaps.py index 527a6f7e..92aa7890 100644 --- a/SimPEG/PropMaps.py +++ b/SimPEG/PropMaps.py @@ -34,6 +34,18 @@ class Property(object): setattr(self, '_%sMap'%prop.name, val) return property(fget=fget, fset=fset, doc=prop.doc) + def _getDefaultProperty(self): + prop = self + def fget(self): + return getattr(self, '_%sDefault'%prop.name, None) + def fset(self, val): + if prop.propertyLink is not None: + linkName, linkMap = prop.propertyLink + assert getattr(self, '%sDefault'%linkName, None) is None, 'Cannot set both sides of a linked property.' + assert isinstance(val, np.ndarray) or np.isscalar(val), 'Default must be a scalar or a numpy array.' + setattr(self, '_%sDefault'%prop.name, val) + return property(fget=fget, fset=fset, doc=prop.doc) + def _getIndexProperty(self): prop = self def fget(self): @@ -47,12 +59,17 @@ class Property(object): def fget(self): mapping = getattr(self, '%sMap'%prop.name) if mapping is None and prop.propertyLink is None: - return prop.defaultVal + return getattr(self, '%sDefault'%prop.name) if mapping is None and prop.propertyLink is not None: linkName, linkMapClass = prop.propertyLink linkMap = linkMapClass(None) - if getattr(self, '%sMap'%linkName, None) is None: + # * + print linkName, getattr(self.propMap, '_%sDefault'%linkName, None) + if getattr(self, '%sMap'%linkName, None) is None and getattr(self.propMap, '_%sDefault'%linkName, None) is not None: + # We have a default + return linkMap * getattr(self, '%sDefault'%linkName, None) + elif getattr(self, '%sMap'%linkName, None) is None: return prop.defaultVal m = getattr(self, '%s'%linkName) return linkMap * m @@ -110,6 +127,12 @@ class Property(object): return getattr(self.propMap, '_%sMap'%prop.name, None) return property(fget=fget) + def _getModelDefaultProperty(self): + prop = self + def fget(self): + return getattr(self.propMap, '_%sDefault'%prop.name, prop.defaultVal) + return property(fget=fget) + class PropModel(object): @@ -150,8 +173,9 @@ class _PropMapMetaClass(type): for attr in keys: if isinstance(attrs[attr], Property): attrs[attr].name = attr - attrs[attr + 'Map' ] = attrs[attr]._getMapProperty() - attrs[attr + 'Index'] = attrs[attr]._getIndexProperty() + attrs[attr + 'Map' ] = attrs[attr]._getMapProperty() + attrs[attr + 'Default'] = attrs[attr]._getDefaultProperty() + attrs[attr + 'Index' ] = attrs[attr]._getIndexProperty() _properties[attr] = attrs[attr] attrs.pop(attr) @@ -181,11 +205,12 @@ class _PropMapMetaClass(type): for attr in _properties: prop = _properties[attr] - attrs[attr ] = prop._getProperty() - attrs[attr + 'Map' ] = prop._getModelMapProperty() - attrs[attr + 'Proj' ] = prop._getModelProjProperty() - attrs[attr + 'Model'] = prop._getModelProperty() - attrs[attr + 'Deriv'] = prop._getModelDerivProperty() + attrs[attr ] = prop._getProperty() + attrs[attr + 'Map' ] = prop._getModelMapProperty() + attrs[attr + 'Default'] = prop._getModelDefaultProperty() + attrs[attr + 'Proj' ] = prop._getModelProjProperty() + attrs[attr + 'Model' ] = prop._getModelProperty() + attrs[attr + 'Deriv' ] = prop._getModelDerivProperty() return type(name.replace('PropMap', 'PropModel'), (PropModel, ), attrs) @@ -198,8 +223,8 @@ class PropMap(object): PropMap takes a multi parameter model and maps it to the equivalent PropModel """ if type(mappings) is dict: - assert np.all([k in ['maps', 'slices'] for k in mappings]), 'Dict must only have properties "maps" and "slices"' - self.setup(mappings['maps'], slices=mappings['slices']) + assert np.all([k in ['maps', 'slices', 'defaults'] for k in mappings]), 'Dict must only have properties "maps", "slices" and "defaults"' + self.setup(mappings['maps'], slices=mappings.get('slices',{}), defaults=mappings.get('defaults',{})) elif type(mappings) is list: self.setup(mappings) elif isinstance(mappings, Maps.IdentityMap): @@ -208,7 +233,7 @@ class PropMap(object): raise Exception('mappings must be a dict, a mapping, or a list of tuples.') - def setup(self, maps, slices=None): + def setup(self, maps, slices=None, defaults=None): """ Sets up the maps and slices for the PropertyMap @@ -231,6 +256,13 @@ class PropMap(object): s in self._properties and (type(slices[s]) in [slice, list] or isinstance(slices[s], np.ndarray)) for s in slices]), 'Slices must be for each property' + if defaults is None: + defaults = dict() + else: + assert np.all([ + s in self._properties and + (np.isscalar(defaults[s]) or isinstance(defaults[s], np.ndarray)) + for s in defaults]), 'Defaults must be for each property' self.clearMaps() @@ -239,7 +271,12 @@ class PropMap(object): setattr(self, '%sMap'%name, mapping) setattr(self, '%sIndex'%name, slices.get(name, slice(nP, nP + mapping.nP))) nP += mapping.nP - self.nP = nP + self.nP = nP + + for key in defaults: + setattr(self, '%sDefault'%key, defaults[key]) + + @property def defaultInvProp(self): diff --git a/tests/base/test_PropMaps.py b/tests/base/test_PropMaps.py index ef22aaad..a234cd69 100644 --- a/tests/base/test_PropMaps.py +++ b/tests/base/test_PropMaps.py @@ -56,8 +56,22 @@ class TestPropMaps(unittest.TestCase): assert np.all(m.sigma == np.exp(np.r_[1.,2,3])) assert m.sigmaDeriv is not None + assert m.mu == mu_0 + assert m.nP == 3 + def test_defaultOverride(self): + expMap = Maps.ExpMap(Mesh.TensorMesh((3,))) + PM = MyReciprocalPropMap({'maps':[('sigma', expMap)], 'defaults':{'mu':mu_0*2}}) + self.assertRaises(Exception, MyReciprocalPropMap, {'maps':[('sigma', expMap)], 'defaults':{'mu':mu_0*2, 'mui':5}}) # Cannot set both sides of the default + + m = PM(np.r_[1.,2,3]) + assert np.all(m.sigmaModel == np.r_[1,2,3]) + + self.assertEqual(m.mu, mu_0 * 2) + # self.assertEqual(m.mui, 1/(mu_0 * 2)) + + def test_slices(self): expMap = Maps.ExpMap(Mesh.TensorMesh((3,))) PM = MyPropMap({'maps':[('sigma', expMap)], 'slices':{'sigma':[2,1,0]}})