From f10d6cc2b7602ea990ea37524249e64825471327 Mon Sep 17 00:00:00 2001 From: rowanc1 Date: Tue, 15 Apr 2014 09:47:01 -0700 Subject: [PATCH] Regularization tests --- SimPEG/Regularization.py | 8 ++++---- SimPEG/Tests/test_regularization.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/SimPEG/Regularization.py b/SimPEG/Regularization.py index cc78df36..68e528e3 100644 --- a/SimPEG/Regularization.py +++ b/SimPEG/Regularization.py @@ -25,7 +25,7 @@ class BaseRegularization(object): Utils.setKwargs(self, **kwargs) self.mesh = mesh self.mapping = mapping or Maps.IdentityMap(mesh) - self.mapping._assertMatchesPair(mapPair) + self.mapping._assertMatchesPair(self.mapPair) mref = Parameters.ParameterProperty('mref', default=None, doc='Reference model.') @@ -56,7 +56,7 @@ class BaseRegularization(object): @property def W(self): """Full regularization weighting matrix W.""" - return sp.identity(self.model.nP) + return sp.identity(self.mapping.nP) @Utils.timeIt @@ -206,8 +206,8 @@ class Tikhonov(BaseRegularization): alpha_yy = Utils.dependentProperty('_alpha_yy', 0.0, ['_W', '_Wyy'], "Weight for the second derivative in the y direction") alpha_zz = Utils.dependentProperty('_alpha_zz', 0.0, ['_W', '_Wzz'], "Weight for the second derivative in the z direction") - def __init__(self, model, **kwargs): - BaseRegularization.__init__(self, model, **kwargs) + def __init__(self, mesh, mapping=None, **kwargs): + BaseRegularization.__init__(self, mesh, mapping=mapping, **kwargs) @property def Ws(self): diff --git a/SimPEG/Tests/test_regularization.py b/SimPEG/Tests/test_regularization.py index 3df9180c..71265456 100644 --- a/SimPEG/Tests/test_regularization.py +++ b/SimPEG/Tests/test_regularization.py @@ -19,10 +19,10 @@ class RegularizationTests(unittest.TestCase): continue # if 'Regularization' not in R: continue print 'Check:', R - model = r.modelPair(self.mesh2) - reg = r(model) - m = model.example() - reg.mref = model.example()*0 + mapping = r.mapPair(self.mesh2) + reg = r(self.mesh2, mapping=mapping) + m = mapping.example() + reg.mref = mapping.example()*0 passed = checkDerivative(lambda m : [reg.modelObj(m), reg.modelObjDeriv(m)], m, plotIt=False) self.assertTrue(passed)