mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-27 21:08:35 +08:00
Fancy Fields Storage in SimPEG
This commit is contained in:
@@ -190,6 +190,106 @@ class Data(object):
|
||||
indBot += rx.nD
|
||||
|
||||
|
||||
class Fields(object):
|
||||
"""Fancy Field Storage
|
||||
|
||||
u[:,'phi'] = phi
|
||||
print u[tx0,'phi']
|
||||
|
||||
"""
|
||||
|
||||
knownFields = None
|
||||
txPair = BaseTx
|
||||
|
||||
def __init__(self, mesh, survey, **kwargs):
|
||||
self.survey = survey
|
||||
self.mesh = mesh
|
||||
Utils.setKwargs(self, **kwargs)
|
||||
self._fields = {}
|
||||
|
||||
def _initStore(self, name):
|
||||
if name in self._fields:
|
||||
return self._fields[name]
|
||||
|
||||
assert name in self.knownFields, 'field name is not known.'
|
||||
|
||||
loc = self.knownFields[name]
|
||||
|
||||
nP = {'CC': self.mesh.nC,
|
||||
'F': self.mesh.nF,
|
||||
'E': self.mesh.nE}[loc]
|
||||
|
||||
nTx = self.survey.nTx
|
||||
field = np.empty((nP, nTx))
|
||||
|
||||
self._fields[name] = field
|
||||
|
||||
return field
|
||||
|
||||
def _indexAndNameFromKey(self, key):
|
||||
if type(key) is not tuple:
|
||||
key = (key,)
|
||||
if len(key) == 1:
|
||||
key += (None,)
|
||||
|
||||
assert len(key) == 2, 'must be [Tx, fieldName]'
|
||||
|
||||
txTestList, name = key
|
||||
|
||||
if name is not None and name not in self.knownFields:
|
||||
raise KeyError('Invalid field name')
|
||||
|
||||
if type(txTestList) is slice:
|
||||
ind = txTestList
|
||||
else:
|
||||
if type(txTestList) is not list:
|
||||
txTestList = [txTestList]
|
||||
for txTest in txTestList:
|
||||
if not isinstance(txTest, self.txPair):
|
||||
raise KeyError('First index must be a Transmitter')
|
||||
if txTest not in self.survey.txList:
|
||||
raise KeyError('Invalid Transmitter, not in survey list.')
|
||||
|
||||
ind = np.in1d(self.survey.txList, txTestList)
|
||||
|
||||
return ind, name
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
ind, name = self._indexAndNameFromKey(key)
|
||||
if name is None:
|
||||
freq = key
|
||||
assert type(value) is dict, 'New fields must be a dictionary, if field is not specified.'
|
||||
newFields = value
|
||||
elif name in self.knownFields:
|
||||
assert type(value) is np.ndarray, 'Must be set to a numpy array'
|
||||
newFields = {name: value}
|
||||
else:
|
||||
raise Exception('Unknown setter')
|
||||
|
||||
for name in newFields:
|
||||
field = self._initStore(name)
|
||||
NEWF = newFields[name]
|
||||
if field.shape[1] == 1 or NEWF.ndim == 1:
|
||||
NEWF = Utils.mkvc(NEWF,2)
|
||||
field[:,ind] = NEWF
|
||||
|
||||
def __getitem__(self, key):
|
||||
ind, name = self._indexAndNameFromKey(key)
|
||||
if name is None:
|
||||
out = {}
|
||||
for name in self._fields:
|
||||
out[name] = self._fields[name][:,ind]
|
||||
if out[name].shape[1] == 1:
|
||||
out[name] = Utils.mkvc(out[name])
|
||||
return out
|
||||
|
||||
out = self._fields[name][:,ind]
|
||||
if out.shape[1] == 1:
|
||||
out = Utils.mkvc(out)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
class BaseSurvey(object):
|
||||
"""Survey holds the observed data, and the standard deviations."""
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
from SimPEG import *
|
||||
|
||||
class DataTest(unittest.TestCase):
|
||||
class DataAndFieldsTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
mesh = Mesh.TensorMesh([np.ones(n)*5 for n in [10,11,12]],[0,0,-30])
|
||||
@@ -20,6 +20,7 @@ class DataTest(unittest.TestCase):
|
||||
txList = [Tx0,Tx1,Tx2,Tx3,Tx4]
|
||||
survey = Survey.BaseSurvey(txList=txList)
|
||||
self.D = Survey.Data(survey)
|
||||
self.F = Survey.Fields(mesh, survey, knownFields={'phi':'CC','e':'E','b':'F'})
|
||||
self.Tx0 = Tx0
|
||||
self.Tx1 = Tx1
|
||||
self.mesh = mesh
|
||||
@@ -45,7 +46,54 @@ class DataTest(unittest.TestCase):
|
||||
txs += [txs[0]]
|
||||
self.assertRaises(AssertionError, Survey.BaseSurvey, txList=txs)
|
||||
|
||||
def test_SetGet(self):
|
||||
F = self.F
|
||||
nTx = F.survey.nTx
|
||||
e = np.random.rand(F.mesh.nE, nTx)
|
||||
F[:, 'e'] = e
|
||||
b = np.random.rand(F.mesh.nF, nTx)
|
||||
F[:, 'b'] = b
|
||||
|
||||
self.assertTrue(np.all(F[:, 'e'] == e))
|
||||
self.assertTrue(np.all(F[:, 'b'] == b))
|
||||
F[:] = {'b':b,'e':e}
|
||||
self.assertTrue(np.all(F[:, 'e'] == e))
|
||||
self.assertTrue(np.all(F[:, 'b'] == b))
|
||||
|
||||
b = np.random.rand(F.mesh.nF,1)
|
||||
F[self.Tx0, 'b'] = b
|
||||
self.assertTrue(np.all(F[self.Tx0, 'b'] == Utils.mkvc(b)))
|
||||
|
||||
b = np.random.rand(F.mesh.nF)
|
||||
F[self.Tx0, 'b'] = b
|
||||
self.assertTrue(np.all(F[self.Tx0, 'b'] == b))
|
||||
|
||||
phi = np.random.rand(F.mesh.nC,2)
|
||||
F[[self.Tx0,self.Tx1], 'phi'] = phi
|
||||
self.assertTrue(np.all(F[[self.Tx0,self.Tx1], 'phi'] == phi))
|
||||
|
||||
fdict = F[:]
|
||||
self.assertTrue(type(fdict) is dict)
|
||||
self.assertTrue(sorted([k for k in fdict]) == ['b','e','phi'])
|
||||
|
||||
b = np.random.rand(F.mesh.nF, 2)
|
||||
F[[self.Tx0, self.Tx1],'b'] = b
|
||||
self.assertTrue(F[self.Tx0]['b'].shape == (F.mesh.nF,))
|
||||
self.assertTrue(F[self.Tx0,'b'].shape == (F.mesh.nF,))
|
||||
self.assertTrue(np.all(F[self.Tx0,'b'] == b[:,0]))
|
||||
self.assertTrue(np.all(F[self.Tx1,'b'] == b[:,1]))
|
||||
|
||||
def test_assertions(self):
|
||||
freq = [self.Tx0, self.Tx1]
|
||||
bWrongSize = np.random.rand(self.F.mesh.nE, self.F.survey.nTx)
|
||||
def fun(): self.F[freq, 'b'] = bWrongSize
|
||||
self.assertRaises(ValueError, fun)
|
||||
def fun(): self.F[-999.]
|
||||
self.assertRaises(KeyError, fun)
|
||||
def fun(): self.F['notRight']
|
||||
self.assertRaises(KeyError, fun)
|
||||
def fun(): self.F[freq,'notThere']
|
||||
self.assertRaises(KeyError, fun)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user