mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-30 21:11:51 +08:00
Changed CylModel --> Vertical1DModel (not tested on Cyl1DMesh)
This commit is contained in:
+41
-6
@@ -1,5 +1,5 @@
|
||||
import Utils, Parameters, numpy as np, scipy.sparse as sp
|
||||
|
||||
from Tests import checkDerivative
|
||||
|
||||
class BaseModel(object):
|
||||
"""
|
||||
@@ -55,9 +55,13 @@ class BaseModel(object):
|
||||
"""Number of parameters in the model."""
|
||||
return self.mesh.nC
|
||||
|
||||
def example(self, modelType=None):
|
||||
return np.random.rand(self.mesh.nC)
|
||||
def example(self):
|
||||
return np.random.rand(self.nP)
|
||||
|
||||
def test(self):
|
||||
print 'Testing the %s Class!' % self.__class__.__name__
|
||||
m = self.example()
|
||||
return checkDerivative(lambda m : [self.transform(m), self.transformDeriv(m)], m, plotIt=False)
|
||||
|
||||
|
||||
class LogModel(BaseModel):
|
||||
@@ -128,16 +132,47 @@ class LogModel(BaseModel):
|
||||
"""
|
||||
return Utils.sdiag(np.exp(Utils.mkvc(m)))
|
||||
|
||||
class CylModel(BaseModel):
|
||||
"""SimPEG LogModel"""
|
||||
class Vertical1DModel(BaseModel):
|
||||
"""Vertical1DModel
|
||||
|
||||
Given a 1D vector through the last dimension
|
||||
of the mesh, this will extend to the full
|
||||
model space.
|
||||
"""
|
||||
|
||||
def __init__(self, mesh, **kwargs):
|
||||
BaseModel.__init__(self, mesh, **kwargs)
|
||||
|
||||
@property
|
||||
def nP(self):
|
||||
"""The number of cells in the
|
||||
last dimension of the mesh."""
|
||||
return self.mesh.nCv[self.mesh.dim-1]
|
||||
|
||||
def transform(self, m):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
:rtype: numpy.array
|
||||
:return: transformed model
|
||||
"""
|
||||
return m.repeat(self.mesh.nCx)
|
||||
repNum = self.mesh.nCv[:self.mesh.dim-2].prod()
|
||||
return Utils.mkvc(m).repeat(repNum)
|
||||
|
||||
def transformDeriv(self, m):
|
||||
"""
|
||||
:param numpy.array m: model
|
||||
:rtype: scipy.csr_matrix
|
||||
:return: derivative of transformed model
|
||||
"""
|
||||
repNum = self.mesh.nCv[:self.mesh.dim-2].prod()
|
||||
repVec = sp.csr_matrix(
|
||||
(np.ones(repNum),
|
||||
(range(repNum), np.zeros(repNum))
|
||||
), shape=(repNum, 1))
|
||||
return sp.kron(repVec, sp.identity(self.nP))
|
||||
|
||||
if __name__ == '__main__':
|
||||
from SimPEG import *
|
||||
mesh = Mesh.TensorMesh([10,8])
|
||||
model = BaseModel(mesh)
|
||||
model.test()
|
||||
|
||||
@@ -11,17 +11,16 @@ class ModelTests(unittest.TestCase):
|
||||
|
||||
a = np.array([1, 1, 1])
|
||||
b = np.array([1, 2])
|
||||
c = np.array([1, 4])
|
||||
self.mesh2 = Mesh.TensorMesh([a, b], np.array([3, 5]))
|
||||
|
||||
def test_modelTransforms(self):
|
||||
print 'SimPEG.Model.BaseModel: Testing Model Transform'
|
||||
for M in dir(Model):
|
||||
if 'Model' not in M: continue
|
||||
model = getattr(Model, M)(self.mesh2)
|
||||
m = model.example()
|
||||
passed = checkDerivative(lambda m : [model.transform(m), model.transformDeriv(m)], m, plotIt=False)
|
||||
self.assertTrue(passed)
|
||||
try:
|
||||
model = getattr(Model, M)(self.mesh2)
|
||||
assert isinstance(model, Model.BaseModel)
|
||||
except Exception, e:
|
||||
continue
|
||||
self.assertTrue(model.test())
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user