From 583f3ed8d0bc709b353e2b7c65df5ae762afbf2b Mon Sep 17 00:00:00 2001 From: rowanc1 Date: Sun, 19 Jan 2014 11:51:15 -0700 Subject: [PATCH] Added ParameterProperty, which handles the logic for calling the parameter each iteration of the optimization. --- SimPEG/ObjFunction.py | 11 +++++------ SimPEG/Optimization.py | 2 +- SimPEG/Utils/__init__.py | 34 +++++++++++++++++++++++++++++++++- 3 files changed, 39 insertions(+), 8 deletions(-) diff --git a/SimPEG/ObjFunction.py b/SimPEG/ObjFunction.py index c9578c54..d88d5740 100644 --- a/SimPEG/ObjFunction.py +++ b/SimPEG/ObjFunction.py @@ -5,7 +5,8 @@ class BaseObjFunction(object): __metaclass__ = Utils.Save.Savable - beta = None #: Regularization trade-off parameter + beta = Utils.ParameterProperty('beta', default=None, doc='Regularization trade-off parameter') + debug = False #: Print debugging information counter = None #: Set this to a SimPEG.Utils.Counter() if you want to count things @@ -68,16 +69,14 @@ class BaseObjFunction(object): self.phi_d, self.phi_d_last = phi_d, self.phi_d self.phi_m, self.phi_m_last = phi_m, self.phi_m - self._beta = self.beta.get() #TODO: This needs to be fixed. - - f = phi_d + self._beta * phi_m + f = phi_d + self.beta * phi_m out = (f,) if return_g: phi_dDeriv = self.dataObjDeriv(m, u=u) phi_mDeriv = self.reg.modelObjDeriv(m) - g = phi_dDeriv + self._beta * phi_mDeriv + g = phi_dDeriv + self.beta * phi_mDeriv out += (g,) if return_H: @@ -85,7 +84,7 @@ class BaseObjFunction(object): phi_d2Deriv = self.dataObj2Deriv(m, v, u=u) phi_m2Deriv = self.reg.modelObj2Deriv()*v - return phi_d2Deriv + self._beta * phi_m2Deriv + return phi_d2Deriv + self.beta * phi_m2Deriv operator = sp.linalg.LinearOperator( (m.size, m.size), H_fun, dtype=m.dtype ) out += (operator,) diff --git a/SimPEG/Optimization.py b/SimPEG/Optimization.py index afb2b2a2..35b71160 100644 --- a/SimPEG/Optimization.py +++ b/SimPEG/Optimization.py @@ -71,7 +71,7 @@ class IterationPrinters(object): bSet = {"title": "bSet", "value": lambda M: np.sum(M.bindingSet(M.xc)), "width": 8, "format": "%d"} comment = {"title": "Comment", "value": lambda M: M.comment, "width": 12, "format": "%s"} - beta = {"title": "beta", "value": lambda M: M.parent.objFunc.beta.get(), "width": 10, "format": "%1.2e"} + beta = {"title": "beta", "value": lambda M: M.parent.objFunc.beta, "width": 10, "format": "%1.2e"} phi_d = {"title": "phi_d", "value": lambda M: M.parent.objFunc.phi_d, "width": 10, "format": "%1.2e"} phi_m = {"title": "phi_m", "value": lambda M: M.parent.objFunc.phi_m, "width": 10, "format": "%1.2e"} diff --git a/SimPEG/Utils/__init__.py b/SimPEG/Utils/__init__.py index f35ee5a0..8c0f899b 100644 --- a/SimPEG/Utils/__init__.py +++ b/SimPEG/Utils/__init__.py @@ -284,14 +284,32 @@ class Parameter(object): return getattr(self,'_parent',None) @parent.setter def parent(self, p): + startupName = '_startup_paramProperty_'+self._propertyName if getattr(self,'_parent',None) is not None: - print 'Parameter has switched to a new parent!' + delattr(self._parent,startupName) + print 'Warning: Parameter %s has switched to a new parent.' % self._propertyName + if self.debug: print '%s function has been deleted' % startupName self._parent = p + prop = self + def _startup_paramProperty(self, *args): + if prop.debug: print 'initializing %s' % prop._propertyName + prop.initialize() + + hook(self._parent, _startup_paramProperty, name=startupName, overwrite=True) + @property def opt(self): return self.parent.parent.opt + @property + def objFunc(self): + return self.parent + + @property + def reg(self): + return self.parent.reg + def initialize(self): pass @@ -306,6 +324,20 @@ class Parameter(object): raise NotImplementedError('Getting the Parameter is not yet implemented.') +def ParameterProperty(name, default=None, doc=""): + def getter(self): + out = getattr(self,'_'+name,default) + if isinstance(out, Parameter): + out = out.get() + return out + def setter(self, value): + if isinstance(value, Parameter): + value._propertyName = name + value.parent = self + setattr(self, '_'+name, value) + + return property(fget=getter, fset=setter, doc=doc) + if __name__ == '__main__': class MyClass(object):