diff --git a/SimPEG/Directives.py b/SimPEG/Directives.py index 0a98bc09..46b0b087 100644 --- a/SimPEG/Directives.py +++ b/SimPEG/Directives.py @@ -285,45 +285,43 @@ class SaveOutputDictEveryIteration(_SaveEveryIteration): class update_IRLS(InversionDirective): - m = None - eps_min = None - factor = None - gamma = None - phi_m_last = None + eps_min = None + factor = None + gamma = None + phi_m_last = None - def initialize(self): + def initialize(self): - # Scale the regularization for changes in norm - if getattr(self, 'phi_m_last', None) is not None: - self.reg.gamma = 1. - phim_new = self.reg.eval(self.invProb.curModel) - self.gamma = self.phi_m_last / phim_new + # Scale the regularization for changes in norm + if getattr(self, 'phi_m_last', None) is not None: + self.reg.gamma = 1. + phim_new = self.reg.eval(self.invProb.curModel) + self.gamma = self.phi_m_last / phim_new - self.reg.gamma = self.gamma + self.reg.gamma = self.gamma - def endIter(self): - # Cool the threshold parameter - if getattr(self, 'factor', None) is not None: - eps = self.reg.eps / self.factor + def endIter(self): + # Cool the threshold parameter + if getattr(self, 'factor', None) is not None: + eps = self.reg.eps / self.factor - if getattr(self, 'eps_min', None) is not None: - self.reg.eps = np.max([self.eps_min,eps]) - else: - self.reg.eps = eps + if getattr(self, 'eps_min', None) is not None: + self.reg.eps = np.max([self.eps_min,eps]) + else: + self.reg.eps = eps # Update the model used for the IRLS weights - if getattr(self, 'm', None) is None: - self.reg.m = self.invProb.curModel + self.reg.curModel = self.invProb.curModel - # Update the pre-conditioner - diagA = np.sum(self.prob.G**2.,axis=0) + self.invProb.beta*(self.reg.W.T*self.reg.W).diagonal() * (self.reg.mapping * np.ones(self.reg.m.size))**2. - PC = Utils.sdiag(diagA**-1.) + # Update the pre-conditioner + diagA = np.sum(self.prob.G**2.,axis=0) + self.invProb.beta*(self.reg.W.T*self.reg.W).diagonal() * (self.reg.mapping * np.ones(self.reg.m.size))**2. + PC = Utils.sdiag(diagA**-1.) - self.opt.approxHinv = PC + self.opt.approxHinv = PC - phim_new = self.reg.eval(self.invProb.curModel) - self.reg.gamma = self.reg.gamma * self.invProb.phi_m_last / phim_new + phim_new = self.reg.eval(self.invProb.curModel) + self.reg.gamma = self.reg.gamma * self.invProb.phi_m_last / phim_new #============================================================================== # import pylab as plt diff --git a/SimPEG/Regularization.py b/SimPEG/Regularization.py index db1bc39b..6ad94caf 100644 --- a/SimPEG/Regularization.py +++ b/SimPEG/Regularization.py @@ -6,7 +6,7 @@ class RegularizationMesh(object): This contains the operators used in the regularization. Note that these are not necessarily true differential operators, but are constructed from - a SimPEG Mesh. + a SimPEG Mesh. :param Mesh mesh: problem mesh :param numpy.array indActive: bool array, size nC, that is True where we have active cells. Used to reduce the operators so we regularize only on active cells @@ -52,7 +52,7 @@ class RegularizationMesh(object): if getattr(self, '_dim', None) is None: self._dim = self.mesh.dim return self._dim - + @property def _Pac(self): @@ -64,7 +64,7 @@ class RegularizationMesh(object): if getattr(self, '__Pac', None) is None: if self.indActive is None: self.__Pac = Utils.speye(self.mesh.nC) - else: + else: self.__Pac = Utils.speye(self.mesh.nC)[:,self.indActive] return self.__Pac @@ -211,7 +211,7 @@ class RegularizationMesh(object): if getattr(self, '_cellDiffz', None) is None: self._cellDiffz = self._Pafz.T * self.mesh.cellGradz * self._Pac return self._cellDiffz - + @property def faceDiffx(self): """ @@ -233,7 +233,7 @@ class RegularizationMesh(object): if getattr(self, '_faceDiffy', None) is None: self._faceDiffy = self._Pac.T * self.mesh.faceDivy * self._Pafy return self._faceDiffy - + @property def faceDiffz(self): """ @@ -310,7 +310,7 @@ class BaseRegularization(object): if indActive is not None and indActive.dtype != 'bool': tmp = indActive indActive = np.zeros(mesh.nC, dtype=bool) - indActive[tmp] = True + indActive[tmp] = True self.regmesh = RegularizationMesh(mesh,indActive) self.mapping = mapping or self.mapPair(mesh) self.mapping._assertMatchesPair(self.mapPair) @@ -427,7 +427,7 @@ class Tikhonov(BaseRegularization): """Regularization matrix Wx""" if getattr(self, '_Wx', None) is None: Ave_x_vol = self.regmesh.aveCC2Fx * self.regmesh.vol - self._Wx = Utils.sdiag((Ave_x_vol*self.alpha_x)**0.5)*self.regmesh.cellDiffx + self._Wx = Utils.sdiag((Ave_x_vol*self.alpha_x)**0.5)*self.regmesh.cellDiffx return self._Wx @property @@ -640,13 +640,14 @@ class Simple(BaseRegularization): class Sparse(Simple): - eps = 1e-1 - m = None - gamma = 1. - p = 0. - qx = 2. - qy = 2. - qz = 2. + # set default values + eps = 1e-1 + curModel = None # use a model to compute the weights + gamma = 1. + p = 0. + qx = 2. + qy = 2. + qz = 2. def __init__(self, mesh, mapping=None, indActive=None, **kwargs): Simple.__init__(self, mesh, mapping=mapping, indActive=indActive, **kwargs) @@ -655,71 +656,64 @@ class Sparse(Simple): @property def Ws(self): """Regularization matrix Ws""" - if getattr(self, 'm', None) is None: + if getattr(self, 'curModel', None) is None: self.Rs = Utils.speye(self.regmesh.nC) else: - f_m = self.m + f_m = self.curModel self.rs = self.R(f_m , self.p, self.eps) #print "Min rs: " + str(np.max(self.rs)) + "Max rs: " + str(np.min(self.rs)) self.Rs = Utils.sdiag( self.rs ) - self._Ws = Utils.sdiag((self.regmesh.vol*self.alpha_s*self.gamma)**0.5)*self.Rs + return Utils.sdiag((self.regmesh.vol*self.alpha_s*self.gamma)**0.5)*self.Rs - return self._Ws @property def Wx(self): """Regularization matrix Wx""" - if getattr(self, 'm', None) is None: + if getattr(self, 'curModel', None) is None: self.Rx = Utils.speye(self.regmesh.cellDiffxStencil.shape[0]) else: - f_m = self.regmesh.cellDiffxStencil * self.m + f_m = self.regmesh.cellDiffxStencil * self.curModel self.rx = self.R( f_m , self.qx, self.eps) self.Rx = Utils.sdiag( self.rx ) - if getattr(self, '_Wx', None) is None: - self._Wx = Utils.sdiag(( (self.regmesh.aveCC2Fx * self.regmesh.vol) *self.alpha_x*self.gamma)**0.5)*self.Rx*self.regmesh.cellDiffxStencil - return self._Wx + return Utils.sdiag(( (self.regmesh.aveCC2Fx * self.regmesh.vol) *self.alpha_x*self.gamma)**0.5)*self.Rx*self.regmesh.cellDiffxStencil @property def Wy(self): """Regularization matrix Wy""" - if getattr(self, 'm', None) is None: + if getattr(self, 'curModel', None) is None: self.Ry = Utils.speye(self.regmesh.cellDiffyStencil.shape[0]) else: - f_m = self.regmesh.cellDiffyStencil * self.m + f_m = self.regmesh.cellDiffyStencil * self.curModel self.ry = self.R( f_m , self.qy, self.eps) self.Ry = Utils.sdiag( self.ry ) - if getattr(self, '_Wy', None) is None: - self._Wy = Utils.sdiag(((self.regmesh.aveCC2Fy * self.regmesh.vol)*self.alpha_y*self.gamma)**0.5)*self.Ry*self.regmesh.cellDiffyStencil - return self._Wy + return Utils.sdiag(((self.regmesh.aveCC2Fy * self.regmesh.vol)*self.alpha_y*self.gamma)**0.5)*self.Ry*self.regmesh.cellDiffyStencil @property def Wz(self): """Regularization matrix Wz""" - if getattr(self, 'm', None) is None: + if getattr(self, 'curModel', None) is None: self.Rz = Utils.speye(self.regmesh.cellDiffzStencil.shape[0]) else: - f_m = self.regmesh.cellDiffzStencil * self.m + f_m = self.regmesh.cellDiffzStencil * self.curModel self.rz = self.R( f_m , self.qz, self.eps) self.Rz = Utils.sdiag( self.rz ) - if getattr(self, '_Wz', None) is None: - self._Wz = Utils.sdiag(((self.regmesh.aveCC2Fz * self.regmesh.vol)*self.alpha_z*self.gamma)**0.5)*self.Rz*self.regmesh.cellDiffzStencil - return self._Wz + return Utils.sdiag(((self.regmesh.aveCC2Fz * self.regmesh.vol)*self.alpha_z*self.gamma)**0.5)*self.Rz*self.regmesh.cellDiffzStencil - def R(self, f_m , p, dec): + def R(self, f_m , exponent): - eta = (self.eps**(1-p/2.))**0.5 - r = eta / (f_m**2.+self.eps**2.)**((1-p/2.)/2.) + eta = (self.eps**(1-exponent/2.))**0.5 + r = eta / (f_m**2.+self.eps**2.)**((1-exponent/2.)/2.) return r diff --git a/tests/base/test_regularization.py b/tests/base/test_regularization.py index 52cef349..97223015 100644 --- a/tests/base/test_regularization.py +++ b/tests/base/test_regularization.py @@ -18,7 +18,7 @@ class RegularizationTests(unittest.TestCase): mesh3 = Mesh.TensorMesh([hx, hy, hz]) self.meshlist = [mesh1,mesh2, mesh3] - if testReg: + if testReg: def test_regularization(self): for R in dir(Regularization): r = getattr(Regularization, R)