Merge branch 'develop' of https://github.com/simpeg/simpeg into develop

This commit is contained in:
Dave Marchant
2014-02-05 21:51:16 -08:00
3 changed files with 50 additions and 13 deletions
+3
View File
@@ -37,6 +37,9 @@ class Cyl1DMesh(object):
return locals()
h = property(**h())
@property
def dim(self): return 2
def z0():
doc = "The z-origin"
def fget(self):
+41 -6
View File
@@ -1,5 +1,5 @@
import Utils, Parameters, numpy as np, scipy.sparse as sp
from Tests import checkDerivative
class BaseModel(object):
"""
@@ -55,9 +55,13 @@ class BaseModel(object):
"""Number of parameters in the model."""
return self.mesh.nC
def example(self, modelType=None):
return np.random.rand(self.mesh.nC)
def example(self):
return np.random.rand(self.nP)
def test(self):
print 'Testing the %s Class!' % self.__class__.__name__
m = self.example()
return checkDerivative(lambda m : [self.transform(m), self.transformDeriv(m)], m, plotIt=False)
class LogModel(BaseModel):
@@ -128,16 +132,47 @@ class LogModel(BaseModel):
"""
return Utils.sdiag(np.exp(Utils.mkvc(m)))
class CylModel(BaseModel):
"""SimPEG LogModel"""
class Vertical1DModel(BaseModel):
"""Vertical1DModel
Given a 1D vector through the last dimension
of the mesh, this will extend to the full
model space.
"""
def __init__(self, mesh, **kwargs):
BaseModel.__init__(self, mesh, **kwargs)
@property
def nP(self):
"""The number of cells in the
last dimension of the mesh."""
return self.mesh.nCv[self.mesh.dim-1]
def transform(self, m):
"""
:param numpy.array m: model
:rtype: numpy.array
:return: transformed model
"""
return m.repeat(self.mesh.nCx)
repNum = self.mesh.nCv[:self.mesh.dim-2].prod()
return Utils.mkvc(m).repeat(repNum)
def transformDeriv(self, m):
"""
:param numpy.array m: model
:rtype: scipy.csr_matrix
:return: derivative of transformed model
"""
repNum = self.mesh.nCv[:self.mesh.dim-2].prod()
repVec = sp.csr_matrix(
(np.ones(repNum),
(range(repNum), np.zeros(repNum))
), shape=(repNum, 1))
return sp.kron(repVec, sp.identity(self.nP))
if __name__ == '__main__':
from SimPEG import *
mesh = Mesh.TensorMesh([10,8])
model = BaseModel(mesh)
model.test()
+6 -7
View File
@@ -11,17 +11,16 @@ class ModelTests(unittest.TestCase):
a = np.array([1, 1, 1])
b = np.array([1, 2])
c = np.array([1, 4])
self.mesh2 = Mesh.TensorMesh([a, b], np.array([3, 5]))
def test_modelTransforms(self):
print 'SimPEG.Model.BaseModel: Testing Model Transform'
for M in dir(Model):
if 'Model' not in M: continue
model = getattr(Model, M)(self.mesh2)
m = model.example()
passed = checkDerivative(lambda m : [model.transform(m), model.transformDeriv(m)], m, plotIt=False)
self.assertTrue(passed)
try:
model = getattr(Model, M)(self.mesh2)
assert isinstance(model, Model.BaseModel)
except Exception, e:
continue
self.assertTrue(model.test())
if __name__ == '__main__':
unittest.main()