From 4af67b8da4dd8e712b6e1c4958aed9de49ea7133 Mon Sep 17 00:00:00 2001 From: rowanc1 Date: Sun, 18 May 2014 19:50:55 -0700 Subject: [PATCH] more fields updates and tests --- SimPEG/Survey.py | 16 +++++++++-- SimPEG/Tests/test_Survey.py | 55 +++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/SimPEG/Survey.py b/SimPEG/Survey.py index 5fd3318c..29fbafca 100644 --- a/SimPEG/Survey.py +++ b/SimPEG/Survey.py @@ -357,10 +357,17 @@ class Fields(object): else: # Aliased fields alias, loc, func = self.aliasFields[name] + + txII = np.array(self.survey.txList)[ind] + if isinstance(txII, np.ndarray): + txII = txII.tolist() + if len(txII) == 1: + txII = txII[0] + if type(func) is str: assert hasattr(self, func), 'The alias field function is a string, but it does not exist in the Fields class.' func = getattr(self, func) - out = func(self._fields[alias][:,ind], ind) + out = func(self._fields[alias][:,ind], txII) if out.shape[1] == 1: out = Utils.mkvc(out) @@ -446,7 +453,12 @@ class TimeFields(Fields): pointerFields = pointerFields.reshape(pointerShape, order='F') timeII = np.arange(self.survey.prob.nT + 1)[timeInd] - txII = list(np.array(self.survey.txList)[txInd]) + txII = np.array(self.survey.txList)[txInd] + if isinstance(txII, np.ndarray): + txII = txII.tolist() + if len(txII) == 1: + txII = txII[0] + if timeII.size == 1: pointerShapeDeflated = self._correctShape(alias, ind, deflate=True) pointerFields = pointerFields.reshape(pointerShapeDeflated, order='F') diff --git a/SimPEG/Tests/test_Survey.py b/SimPEG/Tests/test_Survey.py index 1657e7db..6c702252 100644 --- a/SimPEG/Tests/test_Survey.py +++ b/SimPEG/Tests/test_Survey.py @@ -164,6 +164,26 @@ class FieldsTest_Alias(unittest.TestCase): F[self.Tx0, 'b'] = F[self.Tx0, 'b'] self.assertRaises(KeyError, f) # can't set a alias attr. + def test_aliasFunction(self): + def alias(e, ind): + self.assertTrue(ind is self.Tx0) + return self.F.mesh.edgeCurl * e + F = Survey.Fields(self.F.mesh, self.F.survey, knownFields={'e':'E'}, aliasFields={'b':['e','F',alias]}) + e = np.random.rand(F.mesh.nE,1) + F[self.Tx0, 'e'] = e + F[self.Tx0, 'b'] + + + def alias(e, ind): + self.assertTrue(type(ind) is list) + self.assertTrue(ind[0] is self.Tx0) + self.assertTrue(ind[1] is self.Tx1) + return self.F.mesh.edgeCurl * e + F = Survey.Fields(self.F.mesh, self.F.survey, knownFields={'e':'E'}, aliasFields={'b':['e','F',alias]}) + e = np.random.rand(F.mesh.nE,2) + F[[self.Tx0, self.Tx1], 'e'] = e + F[[self.Tx0, self.Tx1], 'b'] + class FieldsTest_Time(unittest.TestCase): @@ -337,8 +357,43 @@ class FieldsTest_Time_Aliased(unittest.TestCase): F[self.Tx0, 'e'] = F[self.Tx0, 'e'] self.assertRaises(KeyError, f) # can't set a alias attr. + def test_aliasFunction(self): + nT = self.F.survey.prob.nT + 1 + count = [0] + def alias(e, txInd, timeInd): + count[0] += 1 + self.assertTrue(txInd is self.Tx0) + return self.F.mesh.edgeCurl * e + F = Survey.TimeFields(self.F.mesh, self.F.survey, knownFields={'e':'E'}, aliasFields={'b':['e','F',alias]}) + e = np.random.rand(F.mesh.nE,1,nT) + F[self.Tx0, 'e', :] = e + F[self.Tx0, 'b', :] + self.assertTrue(count[0] == nT) # ensure that this is called for every time separately. + e = np.random.rand(F.mesh.nE,1,1) + F[self.Tx0, 'e', 1] = e + count[0] = 0 + F[self.Tx0, 'b', 1] + self.assertTrue(count[0] == 1) # ensure that this is called only once. + def alias(e, txInd, timeInd): + count[0] += 1 + self.assertTrue(type(txInd) is list) + self.assertTrue(txInd[0] is self.Tx0) + self.assertTrue(txInd[1] is self.Tx1) + return self.F.mesh.edgeCurl * e + F = Survey.TimeFields(self.F.mesh, self.F.survey, knownFields={'e':'E'}, aliasFields={'b':['e','F',alias]}) + e = np.random.rand(F.mesh.nE,2, nT) + F[[self.Tx0, self.Tx1], 'e', :] = e + count[0] = 0 + F[[self.Tx0, self.Tx1], 'b', :] + self.assertTrue(count[0] == nT) # ensure that this is called for every time separately. + e = np.random.rand(F.mesh.nE,2, 1) + F[[self.Tx0, self.Tx1], 'e', 1] = e + count[0] = 0 + F[[self.Tx0, self.Tx1], 'b', 1] + self.assertTrue(count[0] == 1) # ensure that this is called only once. + if __name__ == '__main__': unittest.main()