Aliased fields checks.

This commit is contained in:
rowanc1
2014-05-15 09:50:34 -07:00
parent 93f010bba9
commit 6cc41319a0
2 changed files with 73 additions and 8 deletions
+14 -6
View File
@@ -234,6 +234,9 @@ class Fields(object):
if self.aliasFields is None:
self.aliasFields = {}
allFields = [k for k in self.knownFields] + [a for a in self.aliasFields]
assert len(allFields) == len(set(allFields)), 'Aliased fields and Known Fields have overlapping definitions.'
def _storageShape(self, nP):
nTx = self.survey.nTx
return (nP, nTx)
@@ -338,16 +341,14 @@ class Fields(object):
if name in self._fields:
out = self._fields[name][:,ind]
else:
out = self._getAliasField(name, ind)
# Aliased fields
alias, func = self.aliasFields[name]
out = func(self, self._fields[alias][:,ind], ind)
if out.shape[1] == 1:
out = Utils.mkvc(out)
return out
def _getAliasField(self, name, ind):
alias, func = self.aliasFields[name]
return func(self, self._fields[alias][:,ind], ind)
def __contains__(self, other):
if other in self.aliasFields:
other = self.aliasFields[other][0]
@@ -391,7 +392,14 @@ class TimeFields(Fields):
def _getField(self, name, ind):
txInd, timeInd = ind
out = self._fields[name][:,txInd,timeInd]
if name in self._fields:
out = self._fields[name][:,txInd,timeInd]
else:
# Aliased fields
alias, func = self.aliasFields[name]
out = func(self, self._fields[name][:,txInd,timeInd], txInd, timeInd)
if out.shape[1] == 1:
if out.ndim == 2:
out = out[:,0]
+59 -2
View File
@@ -26,6 +26,11 @@ class DataAndFieldsTest(unittest.TestCase):
self.mesh = mesh
self.XYZ = XYZ
def test_overlappingFields(self):
self.assertRaises(AssertionError, Survey.Fields, self.F.mesh, self.F.survey,
knownFields={'b':'F'},
aliasFields={'b':['b',(lambda F, b, ind: b)]})
def test_data(self):
V = []
for tx in self.D.survey.txList:
@@ -44,7 +49,7 @@ class DataAndFieldsTest(unittest.TestCase):
F = self.F
nTx = F.survey.nTx
self.assertTrue('b' not in F)
self.assertTrue('b' not in F)
self.assertTrue('e' not in F)
e = np.random.rand(F.mesh.nE, nTx)
F[:, 'e'] = e
self.assertTrue('b' not in F)
@@ -134,7 +139,7 @@ class FieldsTest_Alias(unittest.TestCase):
F = self.F
nTx = F.survey.nTx
self.assertTrue('b' not in F)
self.assertTrue('b' not in F)
self.assertTrue('e' not in F)
e = np.random.rand(F.mesh.nE, nTx)
F[:, 'e'] = e
self.assertTrue('b' in F)
@@ -182,6 +187,19 @@ class FieldsTest_Time(unittest.TestCase):
self.mesh = mesh
self.XYZ = XYZ
def test_contains(self):
F = self.F
nTx = F.survey.nTx
nT = F.survey.prob.nT + 1
self.assertTrue('b' not in F)
self.assertTrue('e' not in F)
self.assertTrue('phi' not in F)
e = np.random.rand(F.mesh.nE, nTx, nT)
F[:, 'e', :] = e
self.assertTrue('e' in F)
self.assertTrue('b' not in F)
self.assertTrue('phi' not in F)
def test_SetGet(self):
F = self.F
nTx = F.survey.nTx
@@ -240,5 +258,44 @@ class FieldsTest_Time(unittest.TestCase):
def fun(): self.F[freq,'notThere']
self.assertRaises(KeyError, fun)
class FieldsTest_Time_Aliased(unittest.TestCase):
def setUp(self):
mesh = Mesh.TensorMesh([np.ones(n)*5 for n in [10,11,12]],[0,0,-30])
x = np.linspace(5,10,3)
XYZ = Utils.ndgrid(x,x,np.r_[0.])
txLoc = np.r_[0,0,0.]
rxList0 = Survey.BaseRx(XYZ, 'exi')
Tx0 = Survey.BaseTx(txLoc, 'VMD', [rxList0])
rxList1 = Survey.BaseRx(XYZ, 'bxi')
Tx1 = Survey.BaseTx(txLoc, 'VMD', [rxList1])
rxList2 = Survey.BaseRx(XYZ, 'bxi')
Tx2 = Survey.BaseTx(txLoc, 'VMD', [rxList2])
rxList3 = Survey.BaseRx(XYZ, 'bxi')
Tx3 = Survey.BaseTx(txLoc, 'VMD', [rxList3])
Tx4 = Survey.BaseTx(txLoc, 'VMD', [rxList0, rxList1, rxList2, rxList3])
txList = [Tx0,Tx1,Tx2,Tx3,Tx4]
survey = Survey.BaseSurvey(txList=txList)
prob = Problem.BaseTimeProblem(mesh, timeSteps=[(10.,3), (20.,2)])
survey.pair(prob)
self.F = Survey.TimeFields(mesh, survey, knownFields={'b':'F'}, aliasFields={'e':['b',(lambda F, b, ind: F.mesh.edgeCurl.T * b)]})
self.Tx0 = Tx0
self.Tx1 = Tx1
self.mesh = mesh
self.XYZ = XYZ
def test_contains(self):
F = self.F
nTx = F.survey.nTx
nT = F.survey.prob.nT + 1
self.assertTrue('b' not in F)
self.assertTrue('e' not in F)
b = np.random.rand(F.mesh.nF, nTx, nT)
F[:, 'b', :] = b
self.assertTrue('e' in F)
self.assertTrue('b' in F)
if __name__ == '__main__':
unittest.main()