diff --git a/SimPEG/Model.py b/SimPEG/Model.py index 3ef96036..2438490f 100644 --- a/SimPEG/Model.py +++ b/SimPEG/Model.py @@ -1,5 +1,5 @@ import Utils, Parameters, numpy as np, scipy.sparse as sp - +from Tests import checkDerivative class BaseModel(object): """ @@ -55,9 +55,13 @@ class BaseModel(object): """Number of parameters in the model.""" return self.mesh.nC - def example(self, modelType=None): - return np.random.rand(self.mesh.nC) + def example(self): + return np.random.rand(self.nP) + def test(self): + print 'Testing the %s Class!' % self.__class__.__name__ + m = self.example() + return checkDerivative(lambda m : [self.transform(m), self.transformDeriv(m)], m, plotIt=False) class LogModel(BaseModel): @@ -128,16 +132,47 @@ class LogModel(BaseModel): """ return Utils.sdiag(np.exp(Utils.mkvc(m))) -class CylModel(BaseModel): - """SimPEG LogModel""" +class Vertical1DModel(BaseModel): + """Vertical1DModel + + Given a 1D vector through the last dimension + of the mesh, this will extend to the full + model space. + """ def __init__(self, mesh, **kwargs): BaseModel.__init__(self, mesh, **kwargs) + @property + def nP(self): + """The number of cells in the + last dimension of the mesh.""" + return self.mesh.nCv[self.mesh.dim-1] + def transform(self, m): """ :param numpy.array m: model :rtype: numpy.array :return: transformed model """ - return m.repeat(self.mesh.nCx) + repNum = self.mesh.nCv[:self.mesh.dim-2].prod() + return Utils.mkvc(m).repeat(repNum) + + def transformDeriv(self, m): + """ + :param numpy.array m: model + :rtype: scipy.csr_matrix + :return: derivative of transformed model + """ + repNum = self.mesh.nCv[:self.mesh.dim-2].prod() + repVec = sp.csr_matrix( + (np.ones(repNum), + (range(repNum), np.zeros(repNum)) + ), shape=(repNum, 1)) + return sp.kron(repVec, sp.identity(self.nP)) + +if __name__ == '__main__': + from SimPEG import * + mesh = Mesh.TensorMesh([10,8]) + model = BaseModel(mesh) + model.test() diff --git a/SimPEG/Tests/test_model.py b/SimPEG/Tests/test_model.py index 0c3b2c54..df957ecc 100644 --- a/SimPEG/Tests/test_model.py +++ b/SimPEG/Tests/test_model.py @@ -11,17 +11,16 @@ class ModelTests(unittest.TestCase): a = np.array([1, 1, 1]) b = np.array([1, 2]) - c = np.array([1, 4]) self.mesh2 = Mesh.TensorMesh([a, b], np.array([3, 5])) def test_modelTransforms(self): - print 'SimPEG.Model.BaseModel: Testing Model Transform' for M in dir(Model): - if 'Model' not in M: continue - model = getattr(Model, M)(self.mesh2) - m = model.example() - passed = checkDerivative(lambda m : [model.transform(m), model.transformDeriv(m)], m, plotIt=False) - self.assertTrue(passed) + try: + model = getattr(Model, M)(self.mesh2) + assert isinstance(model, Model.BaseModel) + except Exception, e: + continue + self.assertTrue(model.test()) if __name__ == '__main__': unittest.main()