mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-04 16:15:05 +08:00
Clean up imports in DC, move data to forward folder. clean up init statements.
This commit is contained in:
+1
-1
@@ -3,11 +3,11 @@ import scipy.sparse as sp
|
||||
import utils
|
||||
from utils import Solver
|
||||
import mesh
|
||||
import data
|
||||
import forward
|
||||
import inverse
|
||||
import visualize
|
||||
import examples
|
||||
import tests
|
||||
|
||||
import scipy.version as _v
|
||||
if _v.version < '0.13.0':
|
||||
|
||||
+19
-29
@@ -1,13 +1,6 @@
|
||||
from SimPEG.mesh import TensorMesh
|
||||
from SimPEG.forward import Problem, ModelTransforms
|
||||
from SimPEG.tests import checkDerivative
|
||||
from SimPEG.utils import ModelBuilder, sdiag, mkvc
|
||||
from SimPEG import Solver
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
from SimPEG import *
|
||||
|
||||
|
||||
class DCProblem(ModelTransforms.LogModel, Problem):
|
||||
class DCProblem(forward.ModelTransforms.LogModel, forward.Problem):
|
||||
"""
|
||||
**DCProblem**
|
||||
|
||||
@@ -15,7 +8,7 @@ class DCProblem(ModelTransforms.LogModel, Problem):
|
||||
|
||||
"""
|
||||
def __init__(self, mesh):
|
||||
Problem.__init__(self, mesh)
|
||||
forward.Problem.__init__(self, mesh)
|
||||
self.mesh.setCellGradBC('neumann')
|
||||
|
||||
def reshapeFields(self, u):
|
||||
@@ -55,13 +48,13 @@ class DCProblem(ModelTransforms.LogModel, Problem):
|
||||
|
||||
u = self.reshapeFields(u)
|
||||
|
||||
return mkvc(self.P*u)
|
||||
return utils.mkvc(self.P*u)
|
||||
|
||||
def field(self, m):
|
||||
A = self.createMatrix(m)
|
||||
solve = Solver(A)
|
||||
phi = solve.solve(self.RHS)
|
||||
return mkvc(phi)
|
||||
return utils.mkvc(phi)
|
||||
|
||||
def J(self, m, v, u=None):
|
||||
"""
|
||||
@@ -101,11 +94,11 @@ class DCProblem(ModelTransforms.LogModel, Problem):
|
||||
|
||||
dCdm = np.empty_like(u)
|
||||
for i, ui in enumerate(u.T): # loop over each column
|
||||
dCdm[:, i] = D * ( sdiag( G * ui ) * ( Av_dm * ( mT_dm * v ) ) )
|
||||
dCdm[:, i] = D * ( utils.sdiag( G * ui ) * ( Av_dm * ( mT_dm * v ) ) )
|
||||
|
||||
solve = Solver(dCdu)
|
||||
Jv = - P * solve.solve(dCdm)
|
||||
return mkvc(Jv)
|
||||
return utils.mkvc(Jv)
|
||||
|
||||
def Jt(self, m, v, u=None):
|
||||
"""Takes data, turns it into a model..ish"""
|
||||
@@ -130,7 +123,7 @@ class DCProblem(ModelTransforms.LogModel, Problem):
|
||||
|
||||
Jtv = 0
|
||||
for i, ui in enumerate(u.T): # loop over each column
|
||||
Jtv += sdiag( G * ui ) * ( D.T * w[:,i] )
|
||||
Jtv += utils.sdiag( G * ui ) * ( D.T * w[:,i] )
|
||||
|
||||
Jtv = - mT_dm.T * ( Av_dm.T * Jtv )
|
||||
return Jtv
|
||||
@@ -165,16 +158,13 @@ def genTxRxmat(nelec, spacelec, surfloc, elecini, mesh):
|
||||
return q, Q, rxmidLoc
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
from SimPEG import inverse
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Create the mesh
|
||||
h1 = np.ones(20)
|
||||
h2 = np.ones(100)
|
||||
mesh = TensorMesh([h1,h2])
|
||||
M = mesh.TensorMesh([h1,h2])
|
||||
|
||||
# Create some parameters for the model
|
||||
sig1 = np.log(1)
|
||||
@@ -184,8 +174,8 @@ if __name__ == '__main__':
|
||||
p0 = [5, 10]
|
||||
p1 = [15, 50]
|
||||
condVals = [sig1, sig2]
|
||||
mSynth = ModelBuilder.defineBlockConductivity(p0,p1,mesh.gridCC,condVals)
|
||||
plt.colorbar(mesh.plotImage(mSynth))
|
||||
mSynth = utils.ModelBuilder.defineBlockConductivity(p0,p1,M.gridCC,condVals)
|
||||
plt.colorbar(M.plotImage(mSynth))
|
||||
plt.show()
|
||||
|
||||
# Set up the projection
|
||||
@@ -196,32 +186,32 @@ if __name__ == '__main__':
|
||||
elecend = 0.5+spacelec*(nelec-1)
|
||||
elecLocR = np.linspace(elecini, elecend, nelec)
|
||||
rxmidLoc = (elecLocR[0:nelec-1]+elecLocR[1:nelec])*0.5
|
||||
q, Q, rxmidloc = genTxRxmat(nelec, spacelec, surfloc, elecini, mesh)
|
||||
q, Q, rxmidloc = genTxRxmat(nelec, spacelec, surfloc, elecini, M)
|
||||
P = Q.T
|
||||
|
||||
# Create some data
|
||||
problem = DCProblem(mesh)
|
||||
problem = DCProblem(M)
|
||||
problem.P = P
|
||||
problem.RHS = q
|
||||
data = problem.createSyntheticData(mSynth, std=0.05)
|
||||
|
||||
u = problem.field(mSynth)
|
||||
u = problem.reshapeFields(u)
|
||||
mesh.plotImage(u[:,10])
|
||||
M.plotImage(u[:,10])
|
||||
# plt.show()
|
||||
|
||||
# Now set up the problem to do some minimization
|
||||
# problem.dobs = dobs
|
||||
# problem.std = dobs*0 + 0.05
|
||||
m0 = mesh.gridCC[:,0]*0+sig2
|
||||
m0 = M.gridCC[:,0]*0+sig2
|
||||
|
||||
opt = inverse.InexactGaussNewton(maxIterLS=20, maxIter=10, tolF=1e-6, tolX=1e-6, tolG=1e-6, maxIterCG=6)
|
||||
reg = inverse.Regularization(mesh)
|
||||
opt = inverse.InexactGaussNewton(maxIterLS=20, maxIter=3, tolF=1e-6, tolX=1e-6, tolG=1e-6, maxIterCG=6)
|
||||
reg = inverse.Regularization(M)
|
||||
inv = inverse.Inversion(problem, reg, opt, data, beta0=1e4)
|
||||
|
||||
# Check Derivative
|
||||
derChk = lambda m: [inv.dataObj(m), inv.dataObjDeriv(m)]
|
||||
checkDerivative(derChk, mSynth)
|
||||
tests.checkDerivative(derChk, mSynth)
|
||||
|
||||
|
||||
|
||||
@@ -230,7 +220,7 @@ if __name__ == '__main__':
|
||||
|
||||
m = inv.run(m0)
|
||||
|
||||
plt.colorbar(mesh.plotImage(m))
|
||||
plt.colorbar(M.plotImage(m))
|
||||
print m
|
||||
plt.show()
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from SimPEG import utils, data, np, sp
|
||||
from SimPEG import utils, np, sp
|
||||
import Data
|
||||
norm = np.linalg.norm
|
||||
|
||||
|
||||
@@ -224,7 +225,7 @@ class Problem(object):
|
||||
noise = std*abs(dtrue)*np.random.randn(*dtrue.shape)
|
||||
dobs = dtrue+noise
|
||||
stdev = dobs*0 + std
|
||||
return data.SimPEGData(self, dobs=dobs, std=stdev, dtrue=dtrue, mtrue=m)
|
||||
return Data.SimPEGData(self, dobs=dobs, std=stdev, dtrue=dtrue, mtrue=m)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from Problem import *
|
||||
import ModelTransforms
|
||||
from Data import *
|
||||
|
||||
@@ -1,2 +1,15 @@
|
||||
import TestUtils
|
||||
from TestUtils import checkDerivative, Rosenbrock, OrderTest, getQuadratic
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import os
|
||||
import glob
|
||||
import unittest
|
||||
test_file_strings = glob.glob('test_*.py')
|
||||
module_strings = [str[0:len(str)-3] for str in test_file_strings]
|
||||
suites = [unittest.defaultTestLoader.loadTestsFromName(str) for str
|
||||
in module_strings]
|
||||
testSuite = unittest.TestSuite(suites)
|
||||
|
||||
unittest.TextTestRunner(verbosity=2).run(testSuite)
|
||||
|
||||
Reference in New Issue
Block a user