mirror of
https://github.com/wassname/simpeg.git
synced 2026-07-06 05:16:51 +08:00
Aliased fields checks.
This commit is contained in:
+14
-6
@@ -234,6 +234,9 @@ class Fields(object):
|
||||
if self.aliasFields is None:
|
||||
self.aliasFields = {}
|
||||
|
||||
allFields = [k for k in self.knownFields] + [a for a in self.aliasFields]
|
||||
assert len(allFields) == len(set(allFields)), 'Aliased fields and Known Fields have overlapping definitions.'
|
||||
|
||||
def _storageShape(self, nP):
|
||||
nTx = self.survey.nTx
|
||||
return (nP, nTx)
|
||||
@@ -338,16 +341,14 @@ class Fields(object):
|
||||
if name in self._fields:
|
||||
out = self._fields[name][:,ind]
|
||||
else:
|
||||
out = self._getAliasField(name, ind)
|
||||
# Aliased fields
|
||||
alias, func = self.aliasFields[name]
|
||||
out = func(self, self._fields[alias][:,ind], ind)
|
||||
|
||||
if out.shape[1] == 1:
|
||||
out = Utils.mkvc(out)
|
||||
return out
|
||||
|
||||
def _getAliasField(self, name, ind):
|
||||
alias, func = self.aliasFields[name]
|
||||
return func(self, self._fields[alias][:,ind], ind)
|
||||
|
||||
def __contains__(self, other):
|
||||
if other in self.aliasFields:
|
||||
other = self.aliasFields[other][0]
|
||||
@@ -391,7 +392,14 @@ class TimeFields(Fields):
|
||||
|
||||
def _getField(self, name, ind):
|
||||
txInd, timeInd = ind
|
||||
out = self._fields[name][:,txInd,timeInd]
|
||||
|
||||
if name in self._fields:
|
||||
out = self._fields[name][:,txInd,timeInd]
|
||||
else:
|
||||
# Aliased fields
|
||||
alias, func = self.aliasFields[name]
|
||||
out = func(self, self._fields[name][:,txInd,timeInd], txInd, timeInd)
|
||||
|
||||
if out.shape[1] == 1:
|
||||
if out.ndim == 2:
|
||||
out = out[:,0]
|
||||
|
||||
@@ -26,6 +26,11 @@ class DataAndFieldsTest(unittest.TestCase):
|
||||
self.mesh = mesh
|
||||
self.XYZ = XYZ
|
||||
|
||||
def test_overlappingFields(self):
|
||||
self.assertRaises(AssertionError, Survey.Fields, self.F.mesh, self.F.survey,
|
||||
knownFields={'b':'F'},
|
||||
aliasFields={'b':['b',(lambda F, b, ind: b)]})
|
||||
|
||||
def test_data(self):
|
||||
V = []
|
||||
for tx in self.D.survey.txList:
|
||||
@@ -44,7 +49,7 @@ class DataAndFieldsTest(unittest.TestCase):
|
||||
F = self.F
|
||||
nTx = F.survey.nTx
|
||||
self.assertTrue('b' not in F)
|
||||
self.assertTrue('b' not in F)
|
||||
self.assertTrue('e' not in F)
|
||||
e = np.random.rand(F.mesh.nE, nTx)
|
||||
F[:, 'e'] = e
|
||||
self.assertTrue('b' not in F)
|
||||
@@ -134,7 +139,7 @@ class FieldsTest_Alias(unittest.TestCase):
|
||||
F = self.F
|
||||
nTx = F.survey.nTx
|
||||
self.assertTrue('b' not in F)
|
||||
self.assertTrue('b' not in F)
|
||||
self.assertTrue('e' not in F)
|
||||
e = np.random.rand(F.mesh.nE, nTx)
|
||||
F[:, 'e'] = e
|
||||
self.assertTrue('b' in F)
|
||||
@@ -182,6 +187,19 @@ class FieldsTest_Time(unittest.TestCase):
|
||||
self.mesh = mesh
|
||||
self.XYZ = XYZ
|
||||
|
||||
def test_contains(self):
|
||||
F = self.F
|
||||
nTx = F.survey.nTx
|
||||
nT = F.survey.prob.nT + 1
|
||||
self.assertTrue('b' not in F)
|
||||
self.assertTrue('e' not in F)
|
||||
self.assertTrue('phi' not in F)
|
||||
e = np.random.rand(F.mesh.nE, nTx, nT)
|
||||
F[:, 'e', :] = e
|
||||
self.assertTrue('e' in F)
|
||||
self.assertTrue('b' not in F)
|
||||
self.assertTrue('phi' not in F)
|
||||
|
||||
def test_SetGet(self):
|
||||
F = self.F
|
||||
nTx = F.survey.nTx
|
||||
@@ -240,5 +258,44 @@ class FieldsTest_Time(unittest.TestCase):
|
||||
def fun(): self.F[freq,'notThere']
|
||||
self.assertRaises(KeyError, fun)
|
||||
|
||||
|
||||
class FieldsTest_Time_Aliased(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
mesh = Mesh.TensorMesh([np.ones(n)*5 for n in [10,11,12]],[0,0,-30])
|
||||
x = np.linspace(5,10,3)
|
||||
XYZ = Utils.ndgrid(x,x,np.r_[0.])
|
||||
txLoc = np.r_[0,0,0.]
|
||||
rxList0 = Survey.BaseRx(XYZ, 'exi')
|
||||
Tx0 = Survey.BaseTx(txLoc, 'VMD', [rxList0])
|
||||
rxList1 = Survey.BaseRx(XYZ, 'bxi')
|
||||
Tx1 = Survey.BaseTx(txLoc, 'VMD', [rxList1])
|
||||
rxList2 = Survey.BaseRx(XYZ, 'bxi')
|
||||
Tx2 = Survey.BaseTx(txLoc, 'VMD', [rxList2])
|
||||
rxList3 = Survey.BaseRx(XYZ, 'bxi')
|
||||
Tx3 = Survey.BaseTx(txLoc, 'VMD', [rxList3])
|
||||
Tx4 = Survey.BaseTx(txLoc, 'VMD', [rxList0, rxList1, rxList2, rxList3])
|
||||
txList = [Tx0,Tx1,Tx2,Tx3,Tx4]
|
||||
survey = Survey.BaseSurvey(txList=txList)
|
||||
prob = Problem.BaseTimeProblem(mesh, timeSteps=[(10.,3), (20.,2)])
|
||||
survey.pair(prob)
|
||||
self.F = Survey.TimeFields(mesh, survey, knownFields={'b':'F'}, aliasFields={'e':['b',(lambda F, b, ind: F.mesh.edgeCurl.T * b)]})
|
||||
self.Tx0 = Tx0
|
||||
self.Tx1 = Tx1
|
||||
self.mesh = mesh
|
||||
self.XYZ = XYZ
|
||||
|
||||
def test_contains(self):
|
||||
F = self.F
|
||||
nTx = F.survey.nTx
|
||||
nT = F.survey.prob.nT + 1
|
||||
self.assertTrue('b' not in F)
|
||||
self.assertTrue('e' not in F)
|
||||
b = np.random.rand(F.mesh.nF, nTx, nT)
|
||||
F[:, 'b', :] = b
|
||||
self.assertTrue('e' in F)
|
||||
self.assertTrue('b' in F)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user