Fancy Fields Storage in SimPEG

This commit is contained in:
rowanc1
2014-04-26 16:30:34 -07:00
parent e77d85584f
commit cb2c2aec3b
2 changed files with 149 additions and 1 deletions
+100
View File
@@ -190,6 +190,106 @@ class Data(object):
indBot += rx.nD
class Fields(object):
"""Fancy Field Storage
u[:,'phi'] = phi
print u[tx0,'phi']
"""
knownFields = None
txPair = BaseTx
def __init__(self, mesh, survey, **kwargs):
self.survey = survey
self.mesh = mesh
Utils.setKwargs(self, **kwargs)
self._fields = {}
def _initStore(self, name):
if name in self._fields:
return self._fields[name]
assert name in self.knownFields, 'field name is not known.'
loc = self.knownFields[name]
nP = {'CC': self.mesh.nC,
'F': self.mesh.nF,
'E': self.mesh.nE}[loc]
nTx = self.survey.nTx
field = np.empty((nP, nTx))
self._fields[name] = field
return field
def _indexAndNameFromKey(self, key):
if type(key) is not tuple:
key = (key,)
if len(key) == 1:
key += (None,)
assert len(key) == 2, 'must be [Tx, fieldName]'
txTestList, name = key
if name is not None and name not in self.knownFields:
raise KeyError('Invalid field name')
if type(txTestList) is slice:
ind = txTestList
else:
if type(txTestList) is not list:
txTestList = [txTestList]
for txTest in txTestList:
if not isinstance(txTest, self.txPair):
raise KeyError('First index must be a Transmitter')
if txTest not in self.survey.txList:
raise KeyError('Invalid Transmitter, not in survey list.')
ind = np.in1d(self.survey.txList, txTestList)
return ind, name
def __setitem__(self, key, value):
ind, name = self._indexAndNameFromKey(key)
if name is None:
freq = key
assert type(value) is dict, 'New fields must be a dictionary, if field is not specified.'
newFields = value
elif name in self.knownFields:
assert type(value) is np.ndarray, 'Must be set to a numpy array'
newFields = {name: value}
else:
raise Exception('Unknown setter')
for name in newFields:
field = self._initStore(name)
NEWF = newFields[name]
if field.shape[1] == 1 or NEWF.ndim == 1:
NEWF = Utils.mkvc(NEWF,2)
field[:,ind] = NEWF
def __getitem__(self, key):
ind, name = self._indexAndNameFromKey(key)
if name is None:
out = {}
for name in self._fields:
out[name] = self._fields[name][:,ind]
if out[name].shape[1] == 1:
out[name] = Utils.mkvc(out[name])
return out
out = self._fields[name][:,ind]
if out.shape[1] == 1:
out = Utils.mkvc(out)
return out
class BaseSurvey(object):
"""Survey holds the observed data, and the standard deviations."""
+49 -1
View File
@@ -1,7 +1,7 @@
import unittest
from SimPEG import *
class DataTest(unittest.TestCase):
class DataAndFieldsTest(unittest.TestCase):
def setUp(self):
mesh = Mesh.TensorMesh([np.ones(n)*5 for n in [10,11,12]],[0,0,-30])
@@ -20,6 +20,7 @@ class DataTest(unittest.TestCase):
txList = [Tx0,Tx1,Tx2,Tx3,Tx4]
survey = Survey.BaseSurvey(txList=txList)
self.D = Survey.Data(survey)
self.F = Survey.Fields(mesh, survey, knownFields={'phi':'CC','e':'E','b':'F'})
self.Tx0 = Tx0
self.Tx1 = Tx1
self.mesh = mesh
@@ -45,7 +46,54 @@ class DataTest(unittest.TestCase):
txs += [txs[0]]
self.assertRaises(AssertionError, Survey.BaseSurvey, txList=txs)
def test_SetGet(self):
F = self.F
nTx = F.survey.nTx
e = np.random.rand(F.mesh.nE, nTx)
F[:, 'e'] = e
b = np.random.rand(F.mesh.nF, nTx)
F[:, 'b'] = b
self.assertTrue(np.all(F[:, 'e'] == e))
self.assertTrue(np.all(F[:, 'b'] == b))
F[:] = {'b':b,'e':e}
self.assertTrue(np.all(F[:, 'e'] == e))
self.assertTrue(np.all(F[:, 'b'] == b))
b = np.random.rand(F.mesh.nF,1)
F[self.Tx0, 'b'] = b
self.assertTrue(np.all(F[self.Tx0, 'b'] == Utils.mkvc(b)))
b = np.random.rand(F.mesh.nF)
F[self.Tx0, 'b'] = b
self.assertTrue(np.all(F[self.Tx0, 'b'] == b))
phi = np.random.rand(F.mesh.nC,2)
F[[self.Tx0,self.Tx1], 'phi'] = phi
self.assertTrue(np.all(F[[self.Tx0,self.Tx1], 'phi'] == phi))
fdict = F[:]
self.assertTrue(type(fdict) is dict)
self.assertTrue(sorted([k for k in fdict]) == ['b','e','phi'])
b = np.random.rand(F.mesh.nF, 2)
F[[self.Tx0, self.Tx1],'b'] = b
self.assertTrue(F[self.Tx0]['b'].shape == (F.mesh.nF,))
self.assertTrue(F[self.Tx0,'b'].shape == (F.mesh.nF,))
self.assertTrue(np.all(F[self.Tx0,'b'] == b[:,0]))
self.assertTrue(np.all(F[self.Tx1,'b'] == b[:,1]))
def test_assertions(self):
freq = [self.Tx0, self.Tx1]
bWrongSize = np.random.rand(self.F.mesh.nE, self.F.survey.nTx)
def fun(): self.F[freq, 'b'] = bWrongSize
self.assertRaises(ValueError, fun)
def fun(): self.F[-999.]
self.assertRaises(KeyError, fun)
def fun(): self.F['notRight']
self.assertRaises(KeyError, fun)
def fun(): self.F[freq,'notThere']
self.assertRaises(KeyError, fun)
if __name__ == '__main__':
unittest.main()